Unverified commit 76bb5404, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
......@@ -63,12 +63,10 @@ class DenseGraphConv(nn.Module):
--------
`GraphConv <https://docs.dgl.ai/api/python/nn.pytorch.html#graphconv>`__
"""
def __init__(self,
in_feats,
out_feats,
norm='both',
bias=True,
activation=None):
def __init__(
self, in_feats, out_feats, norm="both", bias=True, activation=None
):
super(DenseGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
......@@ -77,7 +75,7 @@ class DenseGraphConv(nn.Module):
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.register_buffer("bias", None)
self.reset_parameters()
self._activation = activation
......@@ -114,7 +112,7 @@ class DenseGraphConv(nn.Module):
dst_degrees = adj.sum(dim=1).clamp(min=1)
feat_src = feat
if self._norm == 'both':
if self._norm == "both":
norm_src = th.pow(src_degrees, -0.5)
shp = norm_src.shape + (1,) * (feat.dim() - 1)
norm_src = th.reshape(norm_src, shp).to(feat.device)
......@@ -129,10 +127,10 @@ class DenseGraphConv(nn.Module):
rst = adj @ feat_src
rst = th.matmul(rst, self.weight)
if self._norm != 'none':
if self._norm == 'both':
if self._norm != "none":
if self._norm == "both":
norm_dst = th.pow(dst_degrees, -0.5)
else: # right
else: # right
norm_dst = 1.0 / dst_degrees
shp = norm_dst.shape + (1,) * (feat.dim() - 1)
norm_dst = th.reshape(norm_dst, shp).to(feat.device)
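# --- Illustrative sketch (not part of this diff): with norm="both" the dense layer
# above computes rst = D_dst^{-1/2} A (D_src^{-1/2} X) W. A minimal pure-PyTorch
# check on a toy dense adjacency (all names below are made up for illustration):
import torch

A = torch.tensor([[0., 1., 1.], [1., 0., 0.], [1., 0., 0.]])  # 3-node dense adjacency
X = torch.randn(3, 4)                                          # node features
W = torch.randn(4, 2)                                          # layer weight
src_deg = A.sum(dim=0).clamp(min=1)
dst_deg = A.sum(dim=1).clamp(min=1)
norm_src = src_deg.pow(-0.5).unsqueeze(-1)                     # D_src^{-1/2}
norm_dst = dst_deg.pow(-0.5).unsqueeze(-1)                     # D_dst^{-1/2}
rst = norm_dst * ((A @ (norm_src * X)) @ W)                    # shape (3, 2)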
......
"""Torch Module for DenseSAGEConv"""
# pylint: disable= no-member, arguments-differ, invalid-name
from torch import nn
from ....utils import check_eq_shape
......@@ -56,13 +57,16 @@ class DenseSAGEConv(nn.Module):
--------
`SAGEConv <https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv>`__
"""
def __init__(self,
in_feats,
out_feats,
feat_drop=0.,
bias=True,
norm=None,
activation=None):
def __init__(
self,
in_feats,
out_feats,
feat_drop=0.0,
bias=True,
norm=None,
activation=None,
):
super(DenseSAGEConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
......@@ -83,7 +87,7 @@ class DenseSAGEConv(nn.Module):
-----
The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
"""
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_uniform_(self.fc.weight, gain=gain)
def forward(self, adj, feat):
......
"""Torch Module for Directional Graph Networks Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from functools import partial
import torch
import torch.nn as nn
from .pnaconv import AGGREGATORS, SCALERS, PNAConv, PNAConvTower
def aggregate_dir_av(h, eig_s, eig_d, eig_idx):
"""directional average aggregation"""
h_mod = torch.mul(h, (
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) /
(torch.sum(torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True, dim=1) + 1e-30)).unsqueeze(-1))
h_mod = torch.mul(
h,
(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])
/ (
torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True,
dim=1,
)
+ 1e-30
)
).unsqueeze(-1),
)
return torch.sum(h_mod, dim=1)
def aggregate_dir_dx(h, eig_s, eig_d, h_in, eig_idx):
"""directional derivative aggregation"""
eig_w = ((
eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) /
(torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True, dim=1) + 1e-30
eig_w = (
(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])
/ (
torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True,
dim=1,
)
+ 1e-30
)
).unsqueeze(-1)
h_mod = torch.mul(h, eig_w)
return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in)
for k in range(1, 4):
AGGREGATORS[f'dir{k}-av'] = partial(aggregate_dir_av, eig_idx=k-1)
AGGREGATORS[f'dir{k}-dx'] = partial(aggregate_dir_dx, eig_idx=k-1)
AGGREGATORS[f"dir{k}-av"] = partial(aggregate_dir_av, eig_idx=k - 1)
AGGREGATORS[f"dir{k}-dx"] = partial(aggregate_dir_dx, eig_idx=k - 1)
class DGNConvTower(PNAConvTower):
"""A single DGN tower with modified reduce function"""
def message(self, edges):
"""message function for DGN layer"""
if self.edge_feat_size > 0:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['a']], dim=-1)
f = torch.cat(
[edges.src["h"], edges.dst["h"], edges.data["a"]], dim=-1
)
else:
f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
return {'msg': self.M(f), 'eig_s': edges.src['eig'], 'eig_d': edges.dst['eig']}
f = torch.cat([edges.src["h"], edges.dst["h"]], dim=-1)
return {
"msg": self.M(f),
"eig_s": edges.src["eig"],
"eig_d": edges.dst["eig"],
}
def reduce_func(self, nodes):
"""reduce function for DGN layer"""
h_in = nodes.data['h']
eig_s = nodes.mailbox['eig_s']
eig_d = nodes.mailbox['eig_d']
msg = nodes.mailbox['msg']
h_in = nodes.data["h"]
eig_s = nodes.mailbox["eig_s"]
eig_d = nodes.mailbox["eig_d"]
msg = nodes.mailbox["msg"]
degree = msg.size(1)
h = []
for agg in self.aggregators:
if agg.startswith('dir'):
if agg.endswith('av'):
if agg.startswith("dir"):
if agg.endswith("av"):
h.append(AGGREGATORS[agg](msg, eig_s, eig_d))
else:
h.append(AGGREGATORS[agg](msg, eig_s, eig_d, h_in))
else:
h.append(AGGREGATORS[agg](msg))
h = torch.cat(h, dim=1)
h = torch.cat([
SCALERS[scaler](h, D=degree, delta=self.delta) if scaler != 'identity' else h
for scaler in self.scalers
], dim=1)
return {'h_neigh': h}
h = torch.cat(
[
SCALERS[scaler](h, D=degree, delta=self.delta)
if scaler != "identity"
else h
for scaler in self.scalers
],
dim=1,
)
return {"h_neigh": h}
class DGNConv(PNAConv):
r"""Directional Graph Network Layer from `Directional Graph Networks
......@@ -154,24 +187,49 @@ class DGNConv(PNAConv):
>>> conv = DGNConv(10, 10, ['dir1-av', 'dir1-dx', 'sum'], ['identity', 'amplification'], 2.5)
>>> ret = conv(g, feat, eig_vec=eig)
"""
def __init__(self, in_size, out_size, aggregators, scalers, delta,
dropout=0., num_towers=1, edge_feat_size=0, residual=True):
def __init__(
self,
in_size,
out_size,
aggregators,
scalers,
delta,
dropout=0.0,
num_towers=1,
edge_feat_size=0,
residual=True,
):
super(DGNConv, self).__init__(
in_size, out_size, aggregators, scalers, delta, dropout,
num_towers, edge_feat_size, residual
in_size,
out_size,
aggregators,
scalers,
delta,
dropout,
num_towers,
edge_feat_size,
residual,
)
self.towers = nn.ModuleList([
DGNConvTower(
self.tower_in_size, self.tower_out_size,
aggregators, scalers, delta,
dropout=dropout, edge_feat_size=edge_feat_size
) for _ in range(num_towers)
])
self.towers = nn.ModuleList(
[
DGNConvTower(
self.tower_in_size,
self.tower_out_size,
aggregators,
scalers,
delta,
dropout=dropout,
edge_feat_size=edge_feat_size,
)
for _ in range(num_towers)
]
)
self.use_eig_vec = False
for aggr in aggregators:
if aggr.startswith('dir'):
if aggr.startswith("dir"):
self.use_eig_vec = True
break
......@@ -203,5 +261,5 @@ class DGNConv(PNAConv):
"""
with graph.local_scope():
if self.use_eig_vec:
graph.ndata['eig'] = eig_vec
graph.ndata["eig"] = eig_vec
return super().forward(graph, node_feat, edge_feat)
......@@ -2,6 +2,7 @@
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn
from .... import function as fn
......@@ -47,6 +48,7 @@ class EGNNConv(nn.Module):
>>> conv = EGNNConv(10, 10, 10, 2)
>>> h, x = conv(g, node_feat, coord_feat, edge_feat)
"""
def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
super(EGNNConv, self).__init__()
......@@ -62,21 +64,21 @@ class EGNNConv(nn.Module):
nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
act_fn,
nn.Linear(hidden_size, hidden_size),
act_fn
act_fn,
)
# \phi_h
self.node_mlp = nn.Sequential(
nn.Linear(in_size + hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, out_size)
nn.Linear(hidden_size, out_size),
)
# \phi_x
self.coord_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, 1, bias=False)
nn.Linear(hidden_size, 1, bias=False),
)
def message(self, edges):
......@@ -84,16 +86,23 @@ class EGNNConv(nn.Module):
# concat features for edge mlp
if self.edge_feat_size > 0:
f = torch.cat(
[edges.src['h'], edges.dst['h'], edges.data['radial'], edges.data['a']],
dim=-1
[
edges.src["h"],
edges.dst["h"],
edges.data["radial"],
edges.data["a"],
],
dim=-1,
)
else:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['radial']], dim=-1)
f = torch.cat(
[edges.src["h"], edges.dst["h"], edges.data["radial"]], dim=-1
)
msg_h = self.edge_mlp(f)
msg_x = self.coord_mlp(msg_h) * edges.data['x_diff']
msg_x = self.coord_mlp(msg_h) * edges.data["x_diff"]
return {'msg_x': msg_x, 'msg_h': msg_h}
return {"msg_x": msg_x, "msg_h": msg_h}
def forward(self, graph, node_feat, coord_feat, edge_feat=None):
r"""
......@@ -126,27 +135,29 @@ class EGNNConv(nn.Module):
"""
with graph.local_scope():
# node feature
graph.ndata['h'] = node_feat
graph.ndata["h"] = node_feat
# coordinate feature
graph.ndata['x'] = coord_feat
graph.ndata["x"] = coord_feat
# edge feature
if self.edge_feat_size > 0:
assert edge_feat is not None, "Edge features must be provided."
graph.edata['a'] = edge_feat
graph.edata["a"] = edge_feat
# get coordinate diff & radial features
graph.apply_edges(fn.u_sub_v('x', 'x', 'x_diff'))
graph.edata['radial'] = graph.edata['x_diff'].square().sum(dim=1).unsqueeze(-1)
graph.apply_edges(fn.u_sub_v("x", "x", "x_diff"))
graph.edata["radial"] = (
graph.edata["x_diff"].square().sum(dim=1).unsqueeze(-1)
)
# normalize coordinate difference
graph.edata['x_diff'] = graph.edata['x_diff'] / (graph.edata['radial'].sqrt() + 1e-30)
graph.edata["x_diff"] = graph.edata["x_diff"] / (
graph.edata["radial"].sqrt() + 1e-30
)
graph.apply_edges(self.message)
graph.update_all(fn.copy_e('msg_x', 'm'), fn.mean('m', 'x_neigh'))
graph.update_all(fn.copy_e('msg_h', 'm'), fn.sum('m', 'h_neigh'))
graph.update_all(fn.copy_e("msg_x", "m"), fn.mean("m", "x_neigh"))
graph.update_all(fn.copy_e("msg_h", "m"), fn.sum("m", "h_neigh"))
h_neigh, x_neigh = graph.ndata['h_neigh'], graph.ndata['x_neigh']
h_neigh, x_neigh = graph.ndata["h_neigh"], graph.ndata["x_neigh"]
h = self.node_mlp(
torch.cat([node_feat, h_neigh], dim=-1)
)
h = self.node_mlp(torch.cat([node_feat, h_neigh], dim=-1))
x = coord_feat + x_neigh
return h, x
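# --- Illustrative sketch (not part of this diff): per edge, forward() builds the
# squared distance ("radial") and the unit-length coordinate difference ("x_diff").
# A single toy edge (u -> v) in pure PyTorch:
import torch

x_u = torch.tensor([0.0, 0.0, 0.0])
x_v = torch.tensor([1.0, 2.0, 2.0])
x_diff = x_u - x_v                                   # fn.u_sub_v("x", "x", "x_diff")
radial = x_diff.square().sum(-1, keepdim=True)       # tensor([9.])
x_diff = x_diff / (radial.sqrt() + 1e-30)            # unit vector, norm ~= 1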
......@@ -58,12 +58,7 @@ class GatedGraphConv(nn.Module):
0.1342, 0.0425]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feats,
out_feats,
n_steps,
n_etypes,
bias=True):
def __init__(self, in_feats, out_feats, n_steps, n_etypes, bias=True):
super(GatedGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
......@@ -87,7 +82,7 @@ class GatedGraphConv(nn.Module):
The model parameters are initialized using Glorot uniform initialization
and the bias is initialized to be zero.
"""
gain = init.calculate_gain('relu')
gain = init.calculate_gain("relu")
self.gru.reset_parameters()
for linear in self.linears:
init.xavier_normal_(linear.weight, gain=gain)
......@@ -134,36 +129,44 @@ class GatedGraphConv(nn.Module):
is the output feature size.
"""
with graph.local_scope():
assert graph.is_homogeneous, \
"not a homogeneous graph; convert it with to_homogeneous " \
assert graph.is_homogeneous, (
"not a homogeneous graph; convert it with to_homogeneous "
"and pass in the edge type as argument"
)
if self._n_etypes != 1:
assert etypes.min() >= 0 and etypes.max() < self._n_etypes, \
"edge type indices out of range [0, {})".format(
self._n_etypes)
assert (
etypes.min() >= 0 and etypes.max() < self._n_etypes
), "edge type indices out of range [0, {})".format(
self._n_etypes
)
zero_pad = feat.new_zeros(
(feat.shape[0], self._out_feats - feat.shape[1]))
(feat.shape[0], self._out_feats - feat.shape[1])
)
feat = th.cat([feat, zero_pad], -1)
for _ in range(self._n_steps):
if self._n_etypes == 1 and etypes is None:
# Fast path when graph has only one edge type
graph.ndata['h'] = self.linears[0](feat)
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'a'))
a = graph.ndata.pop('a') # (N, D)
graph.ndata["h"] = self.linears[0](feat)
graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "a"))
a = graph.ndata.pop("a") # (N, D)
else:
graph.ndata['h'] = feat
graph.ndata["h"] = feat
for i in range(self._n_etypes):
eids = th.nonzero(
etypes == i, as_tuple=False).view(-1).type(graph.idtype)
eids = (
th.nonzero(etypes == i, as_tuple=False)
.view(-1)
.type(graph.idtype)
)
if len(eids) > 0:
graph.apply_edges(
lambda edges: {
'W_e*h': self.linears[i](edges.src['h'])},
eids
"W_e*h": self.linears[i](edges.src["h"])
},
eids,
)
graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
a = graph.ndata.pop('a') # (N, D)
graph.update_all(fn.copy_e("W_e*h", "m"), fn.sum("m", "a"))
a = graph.ndata.pop("a") # (N, D)
feat = self.gru(a, feat)
return feat
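# --- Illustrative sketch (not part of this diff): the per-step state update above is
# a plain GRU cell; the summed messages `a` act as the input and the current node
# features act as the hidden state. Toy example with 5 nodes and 8-dim features
# (the GRUCell here is a stand-in, not the module's own `self.gru`):
import torch
import torch.nn as nn

gru = nn.GRUCell(input_size=8, hidden_size=8)
feat = torch.randn(5, 8)    # current node states
a = torch.randn(5, 8)       # aggregated incoming messages for this step
feat = gru(a, feat)         # next node states, shape (5, 8)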
......@@ -4,10 +4,11 @@ import torch as th
from torch import nn
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ..utils import Identity
from ....utils import expand_as_pair
from ...functional import edge_softmax
from ..utils import Identity
# pylint: enable=W0235
class GATv2Conv(nn.Module):
......@@ -134,18 +135,21 @@ class GATv2Conv(nn.Module):
[-1.1850, 0.1123],
[-0.2002, 0.1155]]], grad_fn=<GSpMMBackward>)
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
bias=True,
share_weights=False):
def __init__(
self,
in_feats,
out_feats,
num_heads,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
bias=True,
share_weights=False,
):
super(GATv2Conv, self).__init__()
self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
......@@ -153,17 +157,21 @@ class GATv2Conv(nn.Module):
self._allow_zero_in_degree = allow_zero_in_degree
if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=bias)
self._in_src_feats, out_feats * num_heads, bias=bias
)
self.fc_dst = nn.Linear(
self._in_dst_feats, out_feats * num_heads, bias=bias)
self._in_dst_feats, out_feats * num_heads, bias=bias
)
else:
self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=bias)
self._in_src_feats, out_feats * num_heads, bias=bias
)
if share_weights:
self.fc_dst = self.fc_src
else:
self.fc_dst = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=bias)
self._in_src_feats, out_feats * num_heads, bias=bias
)
self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
......@@ -171,11 +179,12 @@ class GATv2Conv(nn.Module):
if residual:
if self._in_dst_feats != out_feats * num_heads:
self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=bias)
self._in_dst_feats, num_heads * out_feats, bias=bias
)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
self.register_buffer("res_fc", None)
self.activation = activation
self.share_weights = share_weights
self.bias = bias
......@@ -192,7 +201,7 @@ class GATv2Conv(nn.Module):
The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
The attention weights are initialized using the Xavier method.
"""
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
if self.bias:
nn.init.constant_(self.fc_src.bias, 0)
......@@ -256,53 +265,70 @@ class GATv2Conv(nn.Module):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple):
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
feat_src = self.fc_src(h_src).view(
-1, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst).view(
-1, self._num_heads, self._out_feats
)
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = self.fc_src(h_src).view(
-1, self._num_heads, self._out_feats)
-1, self._num_heads, self._out_feats
)
if self.share_weights:
feat_dst = feat_src
else:
feat_dst = self.fc_dst(h_src).view(
-1, self._num_heads, self._out_feats)
-1, self._num_heads, self._out_feats
)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
graph.srcdata.update({'el': feat_src})# (num_src_edge, num_heads, out_dim)
graph.dstdata.update({'er': feat_dst})
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))# (num_src_edge, num_heads, out_dim)
e = (e * self.attn).sum(dim=-1).unsqueeze(dim=2)# (num_edge, num_heads, 1)
feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_dst = h_dst[: graph.number_of_dst_nodes()]
graph.srcdata.update(
{"el": feat_src}
) # (num_src_edge, num_heads, out_dim)
graph.dstdata.update({"er": feat_dst})
graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(
graph.edata.pop("e")
) # (num_src_edge, num_heads, out_dim)
e = (
(e * self.attn).sum(dim=-1).unsqueeze(dim=2)
) # (num_edge, num_heads, 1)
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) # (num_edge, num_heads)
graph.edata["a"] = self.attn_drop(
edge_softmax(graph, e)
) # (num_edge, num_heads)
# message passing
graph.update_all(fn.u_mul_e('el', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.dstdata['ft']
graph.update_all(fn.u_mul_e("el", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
# residual
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
resval = self.res_fc(h_dst).view(
h_dst.shape[0], -1, self._out_feats
)
rst = rst + resval
# activation
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
return rst, graph.edata["a"]
else:
return rst
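# --- Illustrative sketch (not part of this diff): the GATv2 logit on an edge is
# attn . LeakyReLU(W_src h_u + W_dst h_v), normalised by a softmax over the
# destination node's in-edges, and the output is the attention-weighted sum of the
# transformed source features. Toy version for one destination node, 3 in-edges,
# a single head and out_feats = 4:
import torch

el = torch.randn(3, 4)                       # W_src h_u for each in-edge source
er = torch.randn(1, 4)                       # W_dst h_v for the destination
attn = torch.randn(4)                        # learnable attention vector
e = torch.nn.functional.leaky_relu(el + er, negative_slope=0.2) @ attn   # (3,)
alpha = torch.softmax(e, dim=0)              # attention over the 3 in-edges
rst = (alpha.unsqueeze(-1) * el).sum(dim=0)  # weighted sum, shape (4,)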
......@@ -104,15 +104,17 @@ class GCN2Conv(nn.Module):
"""
def __init__(self,
in_feats,
layer,
alpha=0.1,
lambda_=1,
project_initial_features=True,
allow_zero_in_degree=False,
bias=True,
activation=None):
def __init__(
self,
in_feats,
layer,
alpha=0.1,
lambda_=1,
project_initial_features=True,
allow_zero_in_degree=False,
bias=True,
activation=None,
):
super().__init__()
self._in_feats = in_feats
......@@ -131,7 +133,8 @@ class GCN2Conv(nn.Module):
self.register_parameter("weight2", None)
else:
self.weight2 = nn.Parameter(
th.Tensor(self._in_feats, self._in_feats))
th.Tensor(self._in_feats, self._in_feats)
)
if self._bias:
self.bias = nn.Parameter(th.Tensor(self._in_feats))
......@@ -233,7 +236,7 @@ class GCN2Conv(nn.Module):
norm = th.pow(degs, -0.5)
norm = norm.to(feat.device).unsqueeze(1)
else:
edge_weight = EdgeWeightNorm('both')(graph, edge_weight)
edge_weight = EdgeWeightNorm("both")(graph, edge_weight)
if edge_weight is None:
feat = feat * norm
......@@ -255,14 +258,26 @@ class GCN2Conv(nn.Module):
if self._project_initial_features:
rst = feat.add_(feat_0)
rst = th.addmm(
feat, feat, self.weight1, beta=(1 - self.beta), alpha=self.beta
feat,
feat,
self.weight1,
beta=(1 - self.beta),
alpha=self.beta,
)
else:
rst = th.addmm(
feat, feat, self.weight1, beta=(1 - self.beta), alpha=self.beta
feat,
feat,
self.weight1,
beta=(1 - self.beta),
alpha=self.beta,
)
rst += th.addmm(
feat_0, feat_0, self.weight2, beta=(1 - self.beta), alpha=self.beta
feat_0,
feat_0,
self.weight2,
beta=(1 - self.beta),
alpha=self.beta,
)
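# --- Illustrative sketch (not part of this diff): torch.addmm with these beta/alpha
# arguments implements the GCNII identity mapping
#   (1 - beta) * feat + beta * (feat @ W)
# in a single fused call. Quick numerical check:
import torch

beta = 0.25
feat = torch.randn(6, 8)
W = torch.randn(8, 8)
fused = torch.addmm(feat, feat, W, beta=1 - beta, alpha=beta)
manual = (1 - beta) * feat + beta * (feat @ W)
assert torch.allclose(fused, manual, atol=1e-6)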
if self._bias:
......
......@@ -7,6 +7,7 @@ from torch import nn
from .... import function as fn
from ....utils import expand_as_pair
class GINEConv(nn.Module):
r"""Graph Isomorphism Network with Edge Features, introduced by
`Strategies for Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
......@@ -45,21 +46,19 @@ class GINEConv(nn.Module):
>>> print(res.shape)
torch.Size([4, 20])
"""
def __init__(self,
apply_func=None,
init_eps=0,
learn_eps=False):
def __init__(self, apply_func=None, init_eps=0, learn_eps=False):
super(GINEConv, self).__init__()
self.apply_func = apply_func
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = nn.Parameter(th.FloatTensor([init_eps]))
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
self.register_buffer("eps", th.FloatTensor([init_eps]))
def message(self, edges):
r"""User-defined Message Function"""
return {'m': F.relu(edges.src['hn'] + edges.data['he'])}
return {"m": F.relu(edges.src["hn"] + edges.data["he"])}
def forward(self, graph, node_feat, edge_feat):
r"""Forward computation.
......@@ -89,10 +88,10 @@ class GINEConv(nn.Module):
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(node_feat, graph)
graph.srcdata['hn'] = feat_src
graph.edata['he'] = edge_feat
graph.update_all(self.message, fn.sum('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
graph.srcdata["hn"] = feat_src
graph.edata["he"] = edge_feat
graph.update_all(self.message, fn.sum("m", "neigh"))
rst = (1 + self.eps) * feat_dst + graph.dstdata["neigh"]
if self.apply_func is not None:
rst = self.apply_func(rst)
return rst
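# --- Illustrative sketch (not part of this diff): for a single destination node with
# two in-edges, the update computed above is
#   (1 + eps) * h_v + sum_u ReLU(h_u + e_uv),
# optionally followed by apply_func. Pure-PyTorch toy:
import torch

eps = 0.0
h_v = torch.randn(4)                 # destination node feature
h_u = torch.randn(2, 4)              # source node features of the two in-edges
e_uv = torch.randn(2, 4)             # edge features
neigh = torch.relu(h_u + e_uv).sum(dim=0)
rst = (1 + eps) * h_v + neigh        # shape (4,)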
"""Torch module for grouped reversible residual connections for GNNs"""
# pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
class InvertibleCheckpoint(torch.autograd.Function):
r"""Extension of torch.autograd"""
@staticmethod
def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):
ctx.fn = fn
......@@ -40,19 +43,25 @@ class InvertibleCheckpoint(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("InvertibleCheckpoint is not compatible with .grad(), \
please use .backward() if possible")
raise RuntimeError(
"InvertibleCheckpoint is not compatible with .grad(), \
please use .backward() if possible"
)
# retrieve input and output tensor nodes
if len(ctx.outputs) == 0:
raise RuntimeError("Trying to perform backward on the InvertibleCheckpoint \
for more than once.")
raise RuntimeError(
"Trying to perform backward on the InvertibleCheckpoint \
for more than once."
)
inputs = ctx.inputs.pop()
outputs = ctx.outputs.pop()
# reconstruct input node features
with torch.no_grad():
# inputs[0] is DGLGraph and inputs[1] is input node features
inputs_inverted = ctx.fn_inverse(*((inputs[0], outputs)+inputs[2:]))
inputs_inverted = ctx.fn_inverse(
*((inputs[0], outputs) + inputs[2:])
)
# clear memory of outputs
outputs.storage().resize_(0)
......@@ -72,11 +81,16 @@ class InvertibleCheckpoint(torch.autograd.Function):
detached_inputs = tuple(detached_inputs)
temp_output = ctx.fn(*detached_inputs)
filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False),
detached_inputs))
gradients = torch.autograd.grad(outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs)
filtered_detached_inputs = tuple(
filter(
lambda x: getattr(x, "requires_grad", False), detached_inputs
)
)
gradients = torch.autograd.grad(
outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs,
)
input_gradients = []
i = 0
......@@ -87,7 +101,7 @@ class InvertibleCheckpoint(torch.autograd.Function):
else:
input_gradients.append(None)
gradients = tuple(input_gradients) + gradients[-len(ctx.weights):]
gradients = tuple(input_gradients) + gradients[-len(ctx.weights) :]
return (None, None, None) + gradients
......@@ -157,6 +171,7 @@ class GroupRevRes(nn.Module):
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
"""
def __init__(self, gnn_module, groups=2):
super(GroupRevRes, self).__init__()
self.gnn_modules = nn.ModuleList()
......@@ -173,7 +188,9 @@ class GroupRevRes(nn.Module):
if len(args) == 0:
args_chunks = [()] * self.groups
else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
chunked_args = list(
map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)
)
args_chunks = list(zip(*chunked_args))
y_in = sum(xs[1:])
......@@ -192,13 +209,15 @@ class GroupRevRes(nn.Module):
if len(args) == 0:
args_chunks = [()] * self.groups
else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
chunked_args = list(
map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)
)
args_chunks = list(zip(*chunked_args))
xs = []
for i in range(self.groups-1, -1, -1):
for i in range(self.groups - 1, -1, -1):
if i != 0:
y_in = ys[i-1]
y_in = ys[i - 1]
else:
y_in = sum(xs)
......@@ -232,6 +251,7 @@ class GroupRevRes(nn.Module):
self._forward,
self._inverse,
len(args),
*(args + tuple([p for p in self.parameters() if p.requires_grad])))
*(args + tuple([p for p in self.parameters() if p.requires_grad]))
)
return y
"""Heterogeneous Graph Transformer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import torch
import torch.nn as nn
......@@ -8,6 +9,7 @@ from .... import function as fn
from ..linear import TypedLinear
from ..softmax import edge_softmax
class HGTConv(nn.Module):
r"""Heterogeneous graph transformer convolution from `Heterogeneous Graph Transformer
<https://arxiv.org/abs/2003.01332>`__
......@@ -65,14 +67,17 @@ class HGTConv(nn.Module):
Examples
--------
"""
def __init__(self,
in_size,
head_size,
num_heads,
num_ntypes,
num_etypes,
dropout=0.2,
use_norm=False):
def __init__(
self,
in_size,
head_size,
num_heads,
num_ntypes,
num_etypes,
dropout=0.2,
use_norm=False,
):
super().__init__()
self.in_size = in_size
self.head_size = head_size
......@@ -83,20 +88,33 @@ class HGTConv(nn.Module):
self.linear_k = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_q = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_v = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_a = TypedLinear(head_size * num_heads, head_size * num_heads, num_ntypes)
self.relation_pri = nn.ParameterList([nn.Parameter(torch.ones(num_etypes))
for i in range(num_heads)])
self.relation_att = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)])
self.relation_msg = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)])
self.linear_a = TypedLinear(
head_size * num_heads, head_size * num_heads, num_ntypes
)
self.relation_pri = nn.ParameterList(
[nn.Parameter(torch.ones(num_etypes)) for i in range(num_heads)]
)
self.relation_att = nn.ModuleList(
[
TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)
]
)
self.relation_msg = nn.ModuleList(
[
TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)
]
)
self.skip = nn.Parameter(torch.ones(num_ntypes))
self.drop = nn.Dropout(dropout)
if use_norm:
self.norm = nn.LayerNorm(head_size * num_heads)
if in_size != head_size * num_heads:
self.residual_w = nn.Parameter(torch.Tensor(in_size, head_size * num_heads))
self.residual_w = nn.Parameter(
torch.Tensor(in_size, head_size * num_heads)
)
nn.init.xavier_uniform_(self.residual_w)
def forward(self, g, x, ntype, etype, *, presorted=False):
......@@ -125,17 +143,25 @@ class HGTConv(nn.Module):
"""
self.presorted = presorted
with g.local_scope():
k = self.linear_k(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
q = self.linear_q(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
v = self.linear_v(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
g.srcdata['k'] = k
g.dstdata['q'] = q
g.srcdata['v'] = v
g.edata['etype'] = etype
k = self.linear_k(x, ntype, presorted).view(
-1, self.num_heads, self.head_size
)
q = self.linear_q(x, ntype, presorted).view(
-1, self.num_heads, self.head_size
)
v = self.linear_v(x, ntype, presorted).view(
-1, self.num_heads, self.head_size
)
g.srcdata["k"] = k
g.dstdata["q"] = q
g.srcdata["v"] = v
g.edata["etype"] = etype
g.apply_edges(self.message)
g.edata['m'] = g.edata['m'] * edge_softmax(g, g.edata['a']).unsqueeze(-1)
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'h'))
h = g.dstdata['h'].view(-1, self.num_heads * self.head_size)
g.edata["m"] = g.edata["m"] * edge_softmax(
g, g.edata["a"]
).unsqueeze(-1)
g.update_all(fn.copy_e("m", "m"), fn.sum("m", "h"))
h = g.dstdata["h"].view(-1, self.num_heads * self.head_size)
# target-specific aggregation
h = self.drop(self.linear_a(h, ntype, presorted))
alpha = torch.sigmoid(self.skip[ntype]).unsqueeze(-1)
......@@ -150,12 +176,16 @@ class HGTConv(nn.Module):
def message(self, edges):
"""Message function."""
a, m = [], []
etype = edges.data['etype']
k = torch.unbind(edges.src['k'], dim=1)
q = torch.unbind(edges.dst['q'], dim=1)
v = torch.unbind(edges.src['v'], dim=1)
etype = edges.data["etype"]
k = torch.unbind(edges.src["k"], dim=1)
q = torch.unbind(edges.dst["q"], dim=1)
v = torch.unbind(edges.src["v"], dim=1)
for i in range(self.num_heads):
kw = self.relation_att[i](k[i], etype, self.presorted) # (E, O)
a.append((kw * q[i]).sum(-1) * self.relation_pri[i][etype] / self.sqrt_d) # (E,)
m.append(self.relation_msg[i](v[i], etype, self.presorted)) # (E, O)
return {'a' : torch.stack(a, dim=1), 'm' : torch.stack(m, dim=1)}
a.append(
(kw * q[i]).sum(-1) * self.relation_pri[i][etype] / self.sqrt_d
) # (E,)
m.append(
self.relation_msg[i](v[i], etype, self.presorted)
) # (E, O)
return {"a": torch.stack(a, dim=1), "m": torch.stack(m, dim=1)}
......@@ -4,26 +4,32 @@ import numpy as np
import torch
import torch.nn as nn
def aggregate_mean(h):
"""mean aggregation"""
return torch.mean(h, dim=1)
def aggregate_max(h):
"""max aggregation"""
return torch.max(h, dim=1)[0]
def aggregate_min(h):
"""min aggregation"""
return torch.min(h, dim=1)[0]
def aggregate_sum(h):
"""sum aggregation"""
return torch.sum(h, dim=1)
def aggregate_std(h):
"""standard deviation aggregation"""
return torch.sqrt(aggregate_var(h) + 1e-30)
def aggregate_var(h):
"""variance aggregation"""
h_mean_squares = torch.mean(h * h, dim=1)
......@@ -31,52 +37,76 @@ def aggregate_var(h):
var = torch.relu(h_mean_squares - h_mean * h_mean)
return var
def _aggregate_moment(h, n):
"""moment aggregation: for each node (E[(X-E[X])^n])^{1/n}"""
h_mean = torch.mean(h, dim=1, keepdim=True)
h_n = torch.mean(torch.pow(h - h_mean, n), dim=1)
rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + 1e-30, 1. / n)
rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + 1e-30, 1.0 / n)
return rooted_h_n
def aggregate_moment_3(h):
"""moment aggregation with n=3"""
return _aggregate_moment(h, n=3)
def aggregate_moment_4(h):
"""moment aggregation with n=4"""
return _aggregate_moment(h, n=4)
def aggregate_moment_5(h):
"""moment aggregation with n=5"""
return _aggregate_moment(h, n=5)
def scale_identity(h):
"""identity scaling (no scaling operation)"""
return h
def scale_amplification(h, D, delta):
"""amplification scaling"""
return h * (np.log(D + 1) / delta)
def scale_attenuation(h, D, delta):
"""attenuation scaling"""
return h * (delta / np.log(D + 1))
AGGREGATORS = {
'mean': aggregate_mean, 'sum': aggregate_sum, 'max': aggregate_max, 'min': aggregate_min,
'std': aggregate_std, 'var': aggregate_var, 'moment3': aggregate_moment_3,
'moment4': aggregate_moment_4, 'moment5': aggregate_moment_5
"mean": aggregate_mean,
"sum": aggregate_sum,
"max": aggregate_max,
"min": aggregate_min,
"std": aggregate_std,
"var": aggregate_var,
"moment3": aggregate_moment_3,
"moment4": aggregate_moment_4,
"moment5": aggregate_moment_5,
}
SCALERS = {
'identity': scale_identity,
'amplification': scale_amplification,
'attenuation': scale_attenuation
"identity": scale_identity,
"amplification": scale_amplification,
"attenuation": scale_attenuation,
}
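# --- Illustrative sketch (not part of this diff): the scalers modulate an aggregated
# message by the node degree D relative to the dataset constant delta, i.e.
#   amplification:  h * log(D + 1) / delta
#   attenuation:    h * delta / log(D + 1)
# Toy check with a 4-dim aggregated message:
import torch

h = torch.ones(1, 4)
D, delta = 5, 2.5
print(SCALERS["amplification"](h, D=D, delta=delta))   # h * log(6) / 2.5
print(SCALERS["attenuation"](h, D=D, delta=delta))     # h * 2.5 / log(6)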
class PNAConvTower(nn.Module):
"""A single PNA tower in PNA layers"""
def __init__(self, in_size, out_size, aggregators, scalers,
delta, dropout=0., edge_feat_size=0):
def __init__(
self,
in_size,
out_size,
aggregators,
scalers,
delta,
dropout=0.0,
edge_feat_size=0,
):
super(PNAConvTower, self).__init__()
self.in_size = in_size
self.out_size = out_size
......@@ -86,50 +116,63 @@ class PNAConvTower(nn.Module):
self.edge_feat_size = edge_feat_size
self.M = nn.Linear(2 * in_size + edge_feat_size, in_size)
self.U = nn.Linear((len(aggregators) * len(scalers) + 1) * in_size, out_size)
self.U = nn.Linear(
(len(aggregators) * len(scalers) + 1) * in_size, out_size
)
self.dropout = nn.Dropout(dropout)
self.batchnorm = nn.BatchNorm1d(out_size)
def reduce_func(self, nodes):
"""reduce function for PNA layer:
tensordot of multiple aggregation and scaling operations"""
msg = nodes.mailbox['msg']
msg = nodes.mailbox["msg"]
degree = msg.size(1)
h = torch.cat([AGGREGATORS[agg](msg) for agg in self.aggregators], dim=1)
h = torch.cat([
SCALERS[scaler](h, D=degree, delta=self.delta) if scaler != 'identity' else h
for scaler in self.scalers
], dim=1)
return {'h_neigh': h}
h = torch.cat(
[AGGREGATORS[agg](msg) for agg in self.aggregators], dim=1
)
h = torch.cat(
[
SCALERS[scaler](h, D=degree, delta=self.delta)
if scaler != "identity"
else h
for scaler in self.scalers
],
dim=1,
)
return {"h_neigh": h}
def message(self, edges):
"""message function for PNA layer"""
if self.edge_feat_size > 0:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['a']], dim=-1)
f = torch.cat(
[edges.src["h"], edges.dst["h"], edges.data["a"]], dim=-1
)
else:
f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
return {'msg': self.M(f)}
f = torch.cat([edges.src["h"], edges.dst["h"]], dim=-1)
return {"msg": self.M(f)}
def forward(self, graph, node_feat, edge_feat=None):
"""compute the forward pass of a single tower in PNA convolution layer"""
# calculate graph normalization factors
snorm_n = torch.cat(
[torch.ones(N, 1).to(node_feat) / N for N in graph.batch_num_nodes()],
dim=0
[
torch.ones(N, 1).to(node_feat) / N
for N in graph.batch_num_nodes()
],
dim=0,
).sqrt()
with graph.local_scope():
graph.ndata['h'] = node_feat
graph.ndata["h"] = node_feat
if self.edge_feat_size > 0:
assert edge_feat is not None, "Edge features must be provided."
graph.edata['a'] = edge_feat
graph.edata["a"] = edge_feat
graph.update_all(self.message, self.reduce_func)
h = self.U(
torch.cat([node_feat, graph.ndata['h_neigh']], dim=-1)
)
h = self.U(torch.cat([node_feat, graph.ndata["h_neigh"]], dim=-1))
h = h * snorm_n
return self.dropout(self.batchnorm(h))
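# --- Illustrative sketch (not part of this diff): the reduce step concatenates every
# aggregator over the mailbox and then every scaler over that result, so with
# 3 aggregators, 2 scalers and in_size = 4 each node ends up with a
# 3 * 2 * 4 = 24-dim neighbourhood vector. Toy version using the module-level dicts:
import torch

msg = torch.randn(2, 5, 4)                       # (nodes, in-degree, in_size)
aggs, scls, delta = ["mean", "max", "sum"], ["identity", "amplification"], 2.5
h = torch.cat([AGGREGATORS[a](msg) for a in aggs], dim=1)            # (2, 12)
h = torch.cat(
    [SCALERS[s](h, D=msg.size(1), delta=delta) if s != "identity" else h for s in scls],
    dim=1,
)                                                                     # (2, 24)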
class PNAConv(nn.Module):
r"""Principal Neighbourhood Aggregation Layer from `Principal Neighbourhood Aggregation
for Graph Nets <https://arxiv.org/abs/2004.05718>`__
......@@ -210,14 +253,29 @@ class PNAConv(nn.Module):
>>> conv = PNAConv(10, 10, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5)
>>> ret = conv(g, feat)
"""
def __init__(self, in_size, out_size, aggregators, scalers, delta,
dropout=0., num_towers=1, edge_feat_size=0, residual=True):
def __init__(
self,
in_size,
out_size,
aggregators,
scalers,
delta,
dropout=0.0,
num_towers=1,
edge_feat_size=0,
residual=True,
):
super(PNAConv, self).__init__()
self.in_size = in_size
self.out_size = out_size
assert in_size % num_towers == 0, 'in_size must be divisible by num_towers'
assert out_size % num_towers == 0, 'out_size must be divisible by num_towers'
assert (
in_size % num_towers == 0
), "in_size must be divisible by num_towers"
assert (
out_size % num_towers == 0
), "out_size must be divisible by num_towers"
self.tower_in_size = in_size // num_towers
self.tower_out_size = out_size // num_towers
self.edge_feat_size = edge_feat_size
......@@ -225,17 +283,23 @@ class PNAConv(nn.Module):
if self.in_size != self.out_size:
self.residual = False
self.towers = nn.ModuleList([
PNAConvTower(
self.tower_in_size, self.tower_out_size,
aggregators, scalers, delta,
dropout=dropout, edge_feat_size=edge_feat_size
) for _ in range(num_towers)
])
self.towers = nn.ModuleList(
[
PNAConvTower(
self.tower_in_size,
self.tower_out_size,
aggregators,
scalers,
delta,
dropout=dropout,
edge_feat_size=edge_feat_size,
)
for _ in range(num_towers)
]
)
self.mixing_layer = nn.Sequential(
nn.Linear(out_size, out_size),
nn.LeakyReLU()
nn.Linear(out_size, out_size), nn.LeakyReLU()
)
def forward(self, graph, node_feat, edge_feat=None):
......@@ -261,14 +325,20 @@ class PNAConv(nn.Module):
The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`
should be the same as out_size.
"""
h_cat = torch.cat([
tower(
graph,
node_feat[:, ti * self.tower_in_size: (ti + 1) * self.tower_in_size],
edge_feat
)
for ti, tower in enumerate(self.towers)
], dim=1)
h_cat = torch.cat(
[
tower(
graph,
node_feat[
:,
ti * self.tower_in_size : (ti + 1) * self.tower_in_size,
],
edge_feat,
)
for ti, tower in enumerate(self.towers)
],
dim=1,
)
h_out = self.mixing_layer(h_cat)
# add residual connection
if self.residual:
......
......@@ -6,6 +6,7 @@ from torch import nn
from .... import function as fn
from ..linear import TypedLinear
class RelGraphConv(nn.Module):
r"""Relational graph convolution layer from `Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__
......@@ -94,21 +95,26 @@ class RelGraphConv(nn.Module):
[-0.4323, -0.1440],
[-0.1309, -1.0000]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feat,
out_feat,
num_rels,
regularizer=None,
num_bases=None,
bias=True,
activation=None,
self_loop=True,
dropout=0.0,
layer_norm=False):
def __init__(
self,
in_feat,
out_feat,
num_rels,
regularizer=None,
num_bases=None,
bias=True,
activation=None,
self_loop=True,
dropout=0.0,
layer_norm=False,
):
super().__init__()
if regularizer is not None and num_bases is None:
num_bases = num_rels
self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases)
self.linear_r = TypedLinear(
in_feat, out_feat, num_rels, regularizer, num_bases
)
self.bias = bias
self.activation = activation
self.self_loop = self_loop
......@@ -123,21 +129,25 @@ class RelGraphConv(nn.Module):
# the module only about graph convolution.
# layer norm
if self.layer_norm:
self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)
self.layer_norm_weight = nn.LayerNorm(
out_feat, elementwise_affine=True
)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.loop_weight, gain=nn.init.calculate_gain("relu")
)
self.dropout = nn.Dropout(dropout)
def message(self, edges):
"""Message function."""
m = self.linear_r(edges.src['h'], edges.data['etype'], self.presorted)
if 'norm' in edges.data:
m = m * edges.data['norm']
return {'m' : m}
m = self.linear_r(edges.src["h"], edges.data["etype"], self.presorted)
if "norm" in edges.data:
m = m * edges.data["norm"]
return {"m": m}
def forward(self, g, feat, etypes, norm=None, *, presorted=False):
"""Forward computation.
......@@ -165,20 +175,20 @@ class RelGraphConv(nn.Module):
"""
self.presorted = presorted
with g.local_scope():
g.srcdata['h'] = feat
g.srcdata["h"] = feat
if norm is not None:
g.edata['norm'] = norm
g.edata['etype'] = etypes
g.edata["norm"] = norm
g.edata["etype"] = etypes
# message passing
g.update_all(self.message, fn.sum('m', 'h'))
g.update_all(self.message, fn.sum("m", "h"))
# apply bias and activation
h = g.dstdata['h']
h = g.dstdata["h"]
if self.layer_norm:
h = self.layer_norm_weight(h)
if self.bias:
h = h + self.h_bias
if self.self_loop:
h = h + feat[:g.num_dst_nodes()] @ self.loop_weight
h = h + feat[: g.num_dst_nodes()] @ self.loop_weight
if self.activation:
h = self.activation(h)
h = self.dropout(h)
......
......@@ -82,14 +82,16 @@ class SGConv(nn.Module):
[-1.9441, -0.9343]], grad_fn=<AddmmBackward>)
"""
def __init__(self,
in_feats,
out_feats,
k=1,
cached=False,
bias=True,
norm=None,
allow_zero_in_degree=False):
def __init__(
self,
in_feats,
out_feats,
k=1,
cached=False,
bias=True,
norm=None,
allow_zero_in_degree=False,
):
super(SGConv, self).__init__()
self.fc = nn.Linear(in_feats, out_feats, bias=bias)
self._cached = cached
......@@ -170,20 +172,23 @@ class SGConv(nn.Module):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
msg_func = fn.copy_u("h", "m")
if edge_weight is not None:
graph.edata["_edge_weight"] = EdgeWeightNorm(
'both')(graph, edge_weight)
graph.edata["_edge_weight"] = EdgeWeightNorm("both")(
graph, edge_weight
)
msg_func = fn.u_mul_e("h", "_edge_weight", "m")
if self._cached_h is not None:
......@@ -198,10 +203,9 @@ class SGConv(nn.Module):
for _ in range(self._k):
if edge_weight is None:
feat = feat * norm
graph.ndata['h'] = feat
graph.update_all(msg_func,
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
graph.ndata["h"] = feat
graph.update_all(msg_func, fn.sum("m", "h"))
feat = graph.ndata.pop("h")
if edge_weight is None:
feat = feat * norm
......
......@@ -57,13 +57,14 @@ class TAGConv(nn.Module):
[ 0.3304, -1.9927]], grad_fn=<AddmmBackward>)
"""
def __init__(self,
in_feats,
out_feats,
k=2,
bias=True,
activation=None,
):
def __init__(
self,
in_feats,
out_feats,
k=2,
bias=True,
activation=None,
):
super(TAGConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
......@@ -84,7 +85,7 @@ class TAGConv(nn.Module):
----
The model parameters are initialized using Glorot uniform initialization.
"""
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.lin.weight, gain=gain)
def forward(self, graph, feat, edge_weight=None):
......@@ -114,7 +115,7 @@ class TAGConv(nn.Module):
is size of output feature.
"""
with graph.local_scope():
assert graph.is_homogeneous, 'Graph is not homogeneous'
assert graph.is_homogeneous, "Graph is not homogeneous"
if edge_weight is None:
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
......@@ -122,8 +123,9 @@ class TAGConv(nn.Module):
msg_func = fn.copy_u("h", "m")
if edge_weight is not None:
graph.edata["_edge_weight"] = EdgeWeightNorm(
'both')(graph, edge_weight)
graph.edata["_edge_weight"] = EdgeWeightNorm("both")(
graph, edge_weight
)
msg_func = fn.u_mul_e("h", "_edge_weight", "m")
# D-1/2 A D -1/2 X
fstack = [feat]
......@@ -132,11 +134,10 @@ class TAGConv(nn.Module):
rst = fstack[-1] * norm
else:
rst = fstack[-1]
graph.ndata['h'] = rst
graph.ndata["h"] = rst
graph.update_all(msg_func,
fn.sum(msg='m', out='h'))
rst = graph.ndata['h']
graph.update_all(msg_func, fn.sum(msg="m", out="h"))
rst = graph.ndata["h"]
if edge_weight is None:
rst = rst * norm
fstack.append(rst)
......
......@@ -4,8 +4,10 @@
import torch as tc
import torch.nn as nn
import torch.nn.functional as F
from .... import function as fn
class TWIRLSConv(nn.Module):
r"""Convolution together with iteratively reweighting least squre from
`Graph Neural Networks Inspired by Classical Iterative Algorithms
......@@ -74,27 +76,28 @@ class TWIRLSConv(nn.Module):
torch.Size([6, 2])
"""
def __init__(self,
input_d,
output_d,
hidden_d,
prop_step,
num_mlp_before=1,
num_mlp_after=1,
norm='none',
precond=True,
alp=0,
lam=1,
attention=False,
tau=0.2,
T=-1,
p=1,
use_eta=False,
attn_bef=False,
dropout=0.0,
attn_dropout=0.0,
inp_dropout=0.0,
):
def __init__(
self,
input_d,
output_d,
hidden_d,
prop_step,
num_mlp_before=1,
num_mlp_after=1,
norm="none",
precond=True,
alp=0,
lam=1,
attention=False,
tau=0.2,
T=-1,
p=1,
use_eta=False,
attn_bef=False,
dropout=0.0,
attn_dropout=0.0,
inp_dropout=0.0,
):
super().__init__()
self.input_d = input_d
......@@ -123,7 +126,10 @@ class TWIRLSConv(nn.Module):
# whether we can cache unfolding result
self.cacheable = (
not self.attention) and self.num_mlp_before == 0 and self.inp_dropout <= 0
(not self.attention)
and self.num_mlp_before == 0
and self.inp_dropout <= 0
)
if self.cacheable:
self.cached_unfolding = None
......@@ -136,20 +142,42 @@ class TWIRLSConv(nn.Module):
self.size_bef_unf = self.output_d # as the output of mlp_bef
# ----- computational modules -----
self.mlp_bef = MLP(self.input_d, self.hidden_d, self.size_bef_unf, self.num_mlp_before,
self.dropout, self.norm, init_activate=False)
self.unfolding = TWIRLSUnfoldingAndAttention(self.hidden_d, self.alp, self.lam,
self.prop_step, self.attn_aft, self.tau,
self.T, self.p, self.use_eta, self.init_att,
self.attn_dropout, self.precond)
self.mlp_bef = MLP(
self.input_d,
self.hidden_d,
self.size_bef_unf,
self.num_mlp_before,
self.dropout,
self.norm,
init_activate=False,
)
self.unfolding = TWIRLSUnfoldingAndAttention(
self.hidden_d,
self.alp,
self.lam,
self.prop_step,
self.attn_aft,
self.tau,
self.T,
self.p,
self.use_eta,
self.init_att,
self.attn_dropout,
self.precond,
)
# if there are really transformations before unfolding, then do init_activate in mlp_aft
self.mlp_aft = MLP(self.size_aft_unf, self.hidden_d, self.output_d, self.num_mlp_after,
self.dropout, self.norm,
init_activate=(self.num_mlp_before > 0) and (
self.num_mlp_after > 0)
)
self.mlp_aft = MLP(
self.size_aft_unf,
self.hidden_d,
self.output_d,
self.num_mlp_after,
self.dropout,
self.norm,
init_activate=(self.num_mlp_before > 0)
and (self.num_mlp_after > 0),
)
def forward(self, graph, feat):
r"""
......@@ -212,8 +240,7 @@ class Propagate(nn.Module):
super().__init__()
def _prop(self, graph, Y, lam):
"""propagation part.
"""
"""propagation part."""
Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam)
Y = AX(graph, Y)
Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam)
......@@ -245,8 +272,11 @@ class Propagate(nn.Module):
Propagated feature. :math:`Z^{(k+1)}` in eq.28 in the paper.
"""
return (1 - alp) * Y + alp * lam * self._prop(graph, Y, lam) \
return (
(1 - alp) * Y
+ alp * lam * self._prop(graph, Y, lam)
+ alp * D_power_bias_X(graph, X, -1, lam, 1 - lam)
)
class PropagateNoPrecond(nn.Module):
......@@ -287,7 +317,11 @@ class PropagateNoPrecond(nn.Module):
Propagated feature. :math:`Y^{(k+1)}` in eq.30 in the paper.
"""
return (1 - alp * lam - alp) * Y + alp * lam * normalized_AX(graph, Y) + alp * X
return (
(1 - alp * lam - alp) * Y
+ alp * lam * normalized_AX(graph, Y)
+ alp * X
)
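# --- Illustrative sketch (not part of this diff): one non-preconditioned unfolding
# step mixes the current state Y, its normalised neighbourhood average, and the
# initial features X:
#   Y_{k+1} = (1 - alp*lam - alp) * Y + alp*lam * D^{-1/2} A D^{-1/2} Y + alp * X
# Dense toy version (the dense matrix product stands in for normalized_AX):
import torch

A = torch.tensor([[0., 1.], [1., 0.]])
d_inv_sqrt = A.sum(1).clamp(min=1).pow(-0.5)
A_hat = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
Y = torch.randn(2, 3)
X = torch.randn(2, 3)
alp, lam = 0.5, 1.0
Y_next = (1 - alp * lam - alp) * Y + alp * lam * (A_hat @ Y) + alp * X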
class Attention(nn.Module):
......@@ -372,7 +406,7 @@ class Attention(nn.Module):
# computing edge distance
graph.srcdata["h"] = Y
graph.srcdata["h_norm"] = (Y ** 2).sum(-1)
graph.srcdata["h_norm"] = (Y**2).sum(-1)
graph.apply_edges(fn.u_dot_v("h", "h", "dot_"))
graph.apply_edges(fn.u_add_v("h_norm", "h_norm", "norm_"))
graph.edata["dot_"] = graph.edata["dot_"].view(-1)
......@@ -390,7 +424,8 @@ class Attention(nn.Module):
# FIXME: consider if there is a better way
if self.attn_dropout > 0:
graph.edata["w"] = F.dropout(
graph.edata["w"], self.attn_dropout, training=self.training)
graph.edata["w"], self.attn_dropout, training=self.training
)
return graph
......@@ -410,7 +445,8 @@ def AX(graph, X):
graph.srcdata["h"] = X
graph.update_all(
fn.u_mul_e("h", "w", "m"), fn.sum("m", "h"),
fn.u_mul_e("h", "w", "m"),
fn.sum("m", "h"),
)
Y = graph.dstdata["h"]
......@@ -491,20 +527,21 @@ class TWIRLSUnfoldingAndAttention(nn.Module):
"""
def __init__(self,
d,
alp,
lam,
prop_step,
attn_aft=-1,
tau=0.2,
T=-1,
p=1,
use_eta=False,
init_att=False,
attn_dropout=0,
precond=True,
):
def __init__(
self,
d,
alp,
lam,
prop_step,
attn_aft=-1,
tau=0.2,
T=-1,
p=1,
use_eta=False,
init_att=False,
attn_dropout=0,
precond=True,
):
super().__init__()
......@@ -520,12 +557,15 @@ class TWIRLSUnfoldingAndAttention(nn.Module):
prop_method = Propagate if precond else PropagateNoPrecond
self.prop_layers = nn.ModuleList(
[prop_method() for _ in range(prop_step)])
self.init_attn = Attention(
tau, T, p, attn_dropout) if self.init_att else None
self.attn_layer = Attention(
tau, T, p, attn_dropout) if self.attn_aft >= 0 else None
[prop_method() for _ in range(prop_step)]
)
self.init_attn = (
Attention(tau, T, p, attn_dropout) if self.init_att else None
)
self.attn_layer = (
Attention(tau, T, p, attn_dropout) if self.attn_aft >= 0 else None
)
self.etas = nn.Parameter(tc.ones(d)) if self.use_eta else None
def forward(self, g, X):
......@@ -593,7 +633,16 @@ class MLP(nn.Module):
"""
def __init__(self, input_d, hidden_d, output_d, num_layers, dropout, norm, init_activate):
def __init__(
self,
input_d,
hidden_d,
output_d,
num_layers,
dropout,
norm,
init_activate,
):
super().__init__()
self.init_activate = init_activate
......@@ -611,13 +660,15 @@ class MLP(nn.Module):
self.layers.append(nn.Linear(hidden_d, output_d))
# how many norm layers we have
self.norm_cnt = num_layers-1+int(init_activate)
self.norm_cnt = num_layers - 1 + int(init_activate)
if norm == "batch":
self.norms = nn.ModuleList(
[nn.BatchNorm1d(hidden_d) for _ in range(self.norm_cnt)])
[nn.BatchNorm1d(hidden_d) for _ in range(self.norm_cnt)]
)
elif norm == "layer":
self.norms = nn.ModuleList(
[nn.LayerNorm(hidden_d) for _ in range(self.norm_cnt)])
[nn.LayerNorm(hidden_d) for _ in range(self.norm_cnt)]
)
self.reset_params()
......
......@@ -3,4 +3,4 @@
from .gnnexplainer import GNNExplainer
__all__ = ['GNNExplainer']
__all__ = ["GNNExplainer"]
"""Modules that transforms between graphs and between graph and tensors."""
import torch.nn as nn
from ...transforms import knn_graph, segmented_knn_graph, radius_graph
from ...transforms import knn_graph, radius_graph, segmented_knn_graph
def pairwise_squared_distance(x):
'''
"""
x : (n_samples, n_points, dims)
return : (n_samples, n_points, n_points)
'''
"""
x2s = (x * x).sum(-1, keepdim=True)
return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)
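# --- Illustrative sketch (not part of this diff): a brute-force k-NN on top of the
# helper above is just a top-k over each row of the pairwise distance matrix
# (smallest distances first):
import torch

x = torch.randn(1, 6, 3)                                  # (n_samples, n_points, dims)
dist = pairwise_squared_distance(x)                       # (1, 6, 6)
k = 3
knn_idx = dist.topk(k, dim=-1, largest=False).indices     # (1, 6, 3), includes self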
......@@ -58,13 +60,19 @@ class KNNGraph(nn.Module):
(tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]),
tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
"""
def __init__(self, k):
super(KNNGraph, self).__init__()
self.k = k
#pylint: disable=invalid-name
def forward(self, x, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
# pylint: disable=invalid-name
def forward(
self,
x,
algorithm="bruteforce-blas",
dist="euclidean",
exclude_self=False,
):
r"""
Forward computation.
......@@ -124,8 +132,9 @@ class KNNGraph(nn.Module):
DGLGraph
A DGLGraph without features.
"""
return knn_graph(x, self.k, algorithm=algorithm, dist=dist,
exclude_self=exclude_self)
return knn_graph(
x, self.k, algorithm=algorithm, dist=dist, exclude_self=exclude_self
)
class SegmentedKNNGraph(nn.Module):
......@@ -172,13 +181,20 @@ class SegmentedKNNGraph(nn.Module):
>>>
"""
def __init__(self, k):
super(SegmentedKNNGraph, self).__init__()
self.k = k
#pylint: disable=invalid-name
def forward(self, x, segs, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
# pylint: disable=invalid-name
def forward(
self,
x,
segs,
algorithm="bruteforce-blas",
dist="euclidean",
exclude_self=False,
):
r"""Forward computation.
Parameters
......@@ -240,8 +256,14 @@ class SegmentedKNNGraph(nn.Module):
A batched DGLGraph without features.
"""
return segmented_knn_graph(x, self.k, segs, algorithm=algorithm, dist=dist,
exclude_self=exclude_self)
return segmented_knn_graph(
x,
self.k,
segs,
algorithm=algorithm,
dist=dist,
exclude_self=exclude_self,
)
class RadiusGraph(nn.Module):
......@@ -316,16 +338,21 @@ class RadiusGraph(nn.Module):
[0.7000],
[0.2828]])
"""
#pylint: disable=invalid-name
def __init__(self, r, p=2, self_loop=False,
compute_mode='donot_use_mm_for_euclid_dist'):
# pylint: disable=invalid-name
def __init__(
self,
r,
p=2,
self_loop=False,
compute_mode="donot_use_mm_for_euclid_dist",
):
super(RadiusGraph, self).__init__()
self.r = r
self.p = p
self.self_loop = self_loop
self.compute_mode = compute_mode
#pylint: disable=invalid-name
# pylint: disable=invalid-name
def forward(self, x, get_distances=False):
r"""
Forward computation.
......@@ -351,5 +378,6 @@ class RadiusGraph(nn.Module):
The distances for the edges in the constructed graph. The distances
are in the same order as edge IDs.
"""
return radius_graph(x, self.r, self.p, self.self_loop,
self.compute_mode, get_distances)
return radius_graph(
x, self.r, self.p, self.self_loop, self.compute_mode, get_distances
)
"""Torch modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import numpy as np
import torch as th
import torch.nn as nn
import numpy as np
from ...backend import pytorch as F
from ...base import dgl_warning
from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
from ...readout import (
broadcast_nodes,
max_nodes,
mean_nodes,
softmax_nodes,
sum_nodes,
topk_nodes,
)
__all__ = [
"SumPooling",
"AvgPooling",
"MaxPooling",
"SortPooling",
"GlobalAttentionPooling",
"Set2Set",
"SetTransformerEncoder",
"SetTransformerDecoder",
"WeightAndSum",
]
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
'GlobalAttentionPooling', 'Set2Set',
'SetTransformerEncoder', 'SetTransformerDecoder', 'WeightAndSum']
class SumPooling(nn.Module):
r"""Apply sum pooling over the nodes in a graph.
......@@ -67,6 +81,7 @@ class SumPooling(nn.Module):
tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511],
[1.0608, 1.2080, 2.1780, 2.7849, 2.5420]])
"""
def __init__(self):
super(SumPooling, self).__init__()
......@@ -90,8 +105,8 @@ class SumPooling(nn.Module):
batch size of input graphs.
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = sum_nodes(graph, 'h')
graph.ndata["h"] = feat
readout = sum_nodes(graph, "h")
return readout
......@@ -148,6 +163,7 @@ class AvgPooling(nn.Module):
tensor([[0.7427, 0.6222, 0.8113, 0.5847, 0.4837],
[0.2652, 0.3020, 0.5445, 0.6962, 0.6355]])
"""
def __init__(self):
super(AvgPooling, self).__init__()
......@@ -171,8 +187,8 @@ class AvgPooling(nn.Module):
:math:`B` refers to the batch size of input graphs.
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = mean_nodes(graph, 'h')
graph.ndata["h"] = feat
readout = mean_nodes(graph, "h")
return readout
......@@ -229,6 +245,7 @@ class MaxPooling(nn.Module):
tensor([[0.8948, 0.9030, 0.9137, 0.7567, 0.6118],
[0.5278, 0.6365, 0.9990, 0.9028, 0.8945]])
"""
def __init__(self):
super(MaxPooling, self).__init__()
......@@ -250,8 +267,8 @@ class MaxPooling(nn.Module):
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = max_nodes(graph, 'h')
graph.ndata["h"] = feat
readout = max_nodes(graph, "h")
return readout
......@@ -316,6 +333,7 @@ class SortPooling(nn.Module):
[0.2351, 0.5278, 0.6365, 0.8945, 0.9990, 0.2053, 0.2426, 0.4111, 0.5658,
0.9028]])
"""
def __init__(self, k):
super(SortPooling, self).__init__()
self.k = k
......@@ -342,10 +360,11 @@ class SortPooling(nn.Module):
with graph.local_scope():
# Sort the feature of each node in ascending order.
feat, _ = feat.sort(dim=-1)
graph.ndata['h'] = feat
graph.ndata["h"] = feat
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k, sortby=-1)[0].view(
-1, self.k * feat.shape[-1])
ret = topk_nodes(graph, "h", self.k, sortby=-1)[0].view(
-1, self.k * feat.shape[-1]
)
return ret
......@@ -414,6 +433,7 @@ class GlobalAttentionPooling(nn.Module):
on how to use the GatedGraphConv and GlobalAttentionPooling layers to build a Graph Neural
Network that can solve Sudoku.
"""
def __init__(self, gate_nn, feat_nn=None):
super(GlobalAttentionPooling, self).__init__()
self.gate_nn = gate_nn
......@@ -445,16 +465,18 @@ class GlobalAttentionPooling(nn.Module):
"""
with graph.local_scope():
gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
assert (
gate.shape[-1] == 1
), "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat
graph.ndata['gate'] = gate
gate = softmax_nodes(graph, 'gate')
graph.ndata.pop('gate')
graph.ndata["gate"] = gate
gate = softmax_nodes(graph, "gate")
graph.ndata.pop("gate")
graph.ndata['r'] = feat * gate
readout = sum_nodes(graph, 'r')
graph.ndata.pop('r')
graph.ndata["r"] = feat * gate
readout = sum_nodes(graph, "r")
graph.ndata.pop("r")
if get_attention:
return readout, gate
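A usage sketch with a single-linear gate network (the sizes are illustrative); passing get_attention=True also returns the per-node softmax scores computed above:
import dgl
import torch as th
import torch.nn as nn
from dgl.nn import GlobalAttentionPooling
bg = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
feat = th.rand(bg.num_nodes(), 8)
gap = GlobalAttentionPooling(nn.Linear(8, 1))   # gate_nn must produce one score per node
readout, attn = gap(bg, feat, get_attention=True)
print(readout.shape, attn.shape)                # torch.Size([2, 8]) torch.Size([5, 1])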
......@@ -540,6 +562,7 @@ class Set2Set(nn.Module):
mpnn_predictor.py>`__
on how to use DGL's Set2Set layer in graph property prediction applications.
"""
def __init__(self, input_dim, n_iters, n_layers):
super(Set2Set, self).__init__()
self.input_dim = input_dim
......@@ -574,8 +597,10 @@ class Set2Set(nn.Module):
with graph.local_scope():
batch_size = graph.batch_size
h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
feat.new_zeros((self.n_layers, batch_size, self.input_dim)))
h = (
feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
)
q_star = feat.new_zeros(batch_size, self.output_dim)
......@@ -583,10 +608,10 @@ class Set2Set(nn.Module):
q, h = self.lstm(q_star.unsqueeze(0), h)
q = q.view(batch_size, self.input_dim)
e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
graph.ndata["e"] = e
alpha = softmax_nodes(graph, "e")
graph.ndata["r"] = feat * alpha
readout = sum_nodes(graph, "r")
q_star = th.cat([q, readout], dim=-1)
return q_star
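A usage sketch for Set2Set; because q_star concatenates the LSTM query with the attention readout, the output width is 2 * input_dim (the sizes below are arbitrary):
import dgl
import torch as th
from dgl.nn import Set2Set
bg = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0, 1, 2], [1, 2, 3]))])
feat = th.rand(bg.num_nodes(), 16)
s2s = Set2Set(input_dim=16, n_iters=3, n_layers=1)
print(s2s(bg, feat).shape)   # torch.Size([2, 32]) == (batch_size, 2 * input_dim)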
......@@ -595,12 +620,12 @@ class Set2Set(nn.Module):
"""Set the extra representation of the module.
which will come into effect when printing the model.
"""
summary = 'n_iters={n_iters}'
summary = "n_iters={n_iters}"
return summary.format(**self.__dict__)
def _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y):
""" Generate binary mask array for given x and y input pairs.
"""Generate binary mask array for given x and y input pairs.
Parameters
----------
......@@ -620,9 +645,13 @@ def _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y):
"""
device = lengths_x.device
# x_mask: (batch_size, max_len_x)
x_mask = th.arange(max_len_x, device=device).unsqueeze(0) < lengths_x.unsqueeze(1)
x_mask = th.arange(max_len_x, device=device).unsqueeze(
0
) < lengths_x.unsqueeze(1)
# y_mask: (batch_size, max_len_y)
y_mask = th.arange(max_len_y, device=device).unsqueeze(0) < lengths_y.unsqueeze(1)
y_mask = th.arange(max_len_y, device=device).unsqueeze(
0
) < lengths_y.unsqueeze(1)
# mask: (batch_size, 1, max_len_x, max_len_y)
mask = (x_mask.unsqueeze(-1) & y_mask.unsqueeze(-2)).unsqueeze(1)
return mask
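To make the shapes concrete, a standalone example of the same mask construction with made-up lengths; the result broadcasts over the attention-head dimension:
import torch as th
lengths_x = th.tensor([2, 3])     # valid x positions per batch element
lengths_y = th.tensor([1, 2])     # valid y positions per batch element
max_len_x, max_len_y = 3, 2
x_mask = th.arange(max_len_x).unsqueeze(0) < lengths_x.unsqueeze(1)   # (2, 3)
y_mask = th.arange(max_len_y).unsqueeze(0) < lengths_y.unsqueeze(1)   # (2, 2)
mask = (x_mask.unsqueeze(-1) & y_mask.unsqueeze(-2)).unsqueeze(1)     # (2, 1, 3, 2)
print(mask.shape)   # torch.Size([2, 1, 3, 2])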
......@@ -650,7 +679,10 @@ class MultiHeadAttention(nn.Module):
-----
This module is used in the SetTransformer layer.
"""
def __init__(self, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
def __init__(
self, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0
):
super(MultiHeadAttention, self).__init__()
self.d_model = d_model
self.num_heads = num_heads
......@@ -664,7 +696,7 @@ class MultiHeadAttention(nn.Module):
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropouth),
nn.Linear(d_ff, d_model)
nn.Linear(d_ff, d_model),
)
self.droph = nn.Dropout(dropouth)
self.dropa = nn.Dropout(dropouta)
......@@ -710,25 +742,28 @@ class MultiHeadAttention(nn.Module):
values = F.pad_packed_tensor(values, lengths_mem, 0)
# attention score with shape (B, num_heads, max_len_x, max_len_mem)
e = th.einsum('bxhd,byhd->bhxy', queries, keys)
e = th.einsum("bxhd,byhd->bhxy", queries, keys)
# normalize
e = e / np.sqrt(self.d_head)
# generate mask
mask = _gen_mask(lengths_x, lengths_mem, max_len_x, max_len_mem)
e = e.masked_fill(mask == 0, -float('inf'))
e = e.masked_fill(mask == 0, -float("inf"))
# apply softmax
alpha = th.softmax(e, dim=-1)
# the following line addresses the NaN issue, see
# https://github.com/dmlc/dgl/issues/2657
alpha = alpha.masked_fill(mask == 0, 0.)
alpha = alpha.masked_fill(mask == 0, 0.0)
# sum of value weighted by alpha
out = th.einsum('bhxy,byhd->bxhd', alpha, values)
out = th.einsum("bhxy,byhd->bxhd", alpha, values)
# project to output
out = self.proj_o(
out.contiguous().view(batch_size, max_len_x, self.num_heads * self.d_head))
out.contiguous().view(
batch_size, max_len_x, self.num_heads * self.d_head
)
)
# pack tensor
out = F.pack_padded_tensor(out, lengths_x)
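A shape-only sketch of the two einsum contractions used above, with made-up sizes (batch 2, 4 heads, d_head 8, max_len_x 5, max_len_mem 6):
import torch as th
B, H, D, LX, LM = 2, 4, 8, 5, 6
queries = th.rand(B, LX, H, D)
keys = th.rand(B, LM, H, D)
values = th.rand(B, LM, H, D)
e = th.einsum("bxhd,byhd->bhxy", queries, keys)     # (B, H, LX, LM) attention scores
alpha = th.softmax(e / D ** 0.5, dim=-1)            # normalize by sqrt(d_head)
out = th.einsum("bhxy,byhd->bxhd", alpha, values)   # (B, LX, H, D) weighted values
print(e.shape, out.shape)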
......@@ -764,10 +799,19 @@ class SetAttentionBlock(nn.Module):
-----
This module is used in the SetTransformer layer.
"""
def __init__(self, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
def __init__(
self, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0
):
super(SetAttentionBlock, self).__init__()
self.mha = MultiHeadAttention(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta)
self.mha = MultiHeadAttention(
d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
def forward(self, feat, lengths):
"""
......@@ -808,19 +852,32 @@ class InducedSetAttentionBlock(nn.Module):
-----
This module is used in the SetTransformer layer.
"""
def __init__(self, m, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
def __init__(
self, m, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0
):
super(InducedSetAttentionBlock, self).__init__()
self.m = m
if m == 1:
dgl_warning("if m is set to 1, the parameters corresponding to query and key "
"projections would not get updated during training.")
dgl_warning(
"if m is set to 1, the parameters corresponding to query and key "
"projections would not get updated during training."
)
self.d_model = d_model
self.inducing_points = nn.Parameter(
th.FloatTensor(m, d_model)
self.inducing_points = nn.Parameter(th.FloatTensor(m, d_model))
self.mha = nn.ModuleList(
[
MultiHeadAttention(
d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
for _ in range(2)
]
)
self.mha = nn.ModuleList([
MultiHeadAttention(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta) for _ in range(2)])
self.reset_parameters()
def reset_parameters(self):
......@@ -852,8 +909,10 @@ class InducedSetAttentionBlock(nn.Module):
"""Set the extra representation of the module.
which will come into effect when printing the model.
"""
shape_str = '({}, {})'.format(self.inducing_points.shape[0], self.inducing_points.shape[1])
return 'InducedVector: ' + shape_str
shape_str = "({}, {})".format(
self.inducing_points.shape[0], self.inducing_points.shape[1]
)
return "InducedVector: " + shape_str
class PMALayer(nn.Module):
......@@ -881,23 +940,32 @@ class PMALayer(nn.Module):
-----
This module is used in the SetTransformer layer.
"""
def __init__(self, k, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
def __init__(
self, k, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0
):
super(PMALayer, self).__init__()
self.k = k
if k == 1:
dgl_warning("if k is set to 1, the parameters corresponding to query and key "
"projections would not get updated during training.")
dgl_warning(
"if k is set to 1, the parameters corresponding to query and key "
"projections would not get updated during training."
)
self.d_model = d_model
self.seed_vectors = nn.Parameter(
th.FloatTensor(k, d_model)
self.seed_vectors = nn.Parameter(th.FloatTensor(k, d_model))
self.mha = MultiHeadAttention(
d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
self.mha = MultiHeadAttention(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Dropout(dropouth),
nn.Linear(d_ff, d_model)
nn.Linear(d_ff, d_model),
)
self.reset_parameters()
......@@ -929,8 +997,10 @@ class PMALayer(nn.Module):
"""Set the extra representation of the module.
which will come into effect when printing the model.
"""
shape_str = '({}, {})'.format(self.seed_vectors.shape[0], self.seed_vectors.shape[1])
return 'SeedVector: ' + shape_str
shape_str = "({}, {})".format(
self.seed_vectors.shape[0], self.seed_vectors.shape[1]
)
return "SeedVector: " + shape_str
class SetTransformerEncoder(nn.Module):
......@@ -1018,27 +1088,57 @@ class SetTransformerEncoder(nn.Module):
representation instead of graphwise representation, and the SetTransformerDecoder
would return a graph readout tensor.
"""
def __init__(self, d_model, n_heads, d_head, d_ff,
n_layers=1, block_type='sab', m=None, dropouth=0., dropouta=0.):
def __init__(
self,
d_model,
n_heads,
d_head,
d_ff,
n_layers=1,
block_type="sab",
m=None,
dropouth=0.0,
dropouta=0.0,
):
super(SetTransformerEncoder, self).__init__()
self.n_layers = n_layers
self.block_type = block_type
self.m = m
layers = []
if block_type == 'isab' and m is None:
raise KeyError('The number of inducing points is not specified in ISAB block.')
if block_type == "isab" and m is None:
raise KeyError(
"The number of inducing points is not specified in ISAB block."
)
for _ in range(n_layers):
if block_type == 'sab':
if block_type == "sab":
layers.append(
SetAttentionBlock(d_model, n_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta))
elif block_type == 'isab':
SetAttentionBlock(
d_model,
n_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
)
elif block_type == "isab":
layers.append(
InducedSetAttentionBlock(m, d_model, n_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta))
InducedSetAttentionBlock(
m,
d_model,
n_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
)
else:
raise KeyError("Unrecognized block type {}: we only support sab/isab")
raise KeyError(
"Unrecognized block type {}: we only support sab/isab".format(block_type)
)
self.layers = nn.ModuleList(layers)
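A usage sketch for the encoder (the feature width must equal d_model; the remaining hyperparameters are arbitrary). With block_type="sab" the output stays node-wise; block_type="isab" additionally requires the number of inducing points m:
import dgl
import torch as th
from dgl.nn import SetTransformerEncoder
bg = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0, 1, 2], [1, 2, 3]))])
feat = th.rand(bg.num_nodes(), 16)
enc = SetTransformerEncoder(d_model=16, n_heads=4, d_head=4, d_ff=32, n_layers=1)
print(enc(bg, feat).shape)   # torch.Size([7, 16]): one row per node, not per graph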
......@@ -1136,18 +1236,43 @@ class SetTransformerDecoder(nn.Module):
--------
SetTransformerEncoder
"""
def __init__(self, d_model, num_heads, d_head, d_ff, n_layers, k, dropouth=0., dropouta=0.):
def __init__(
self,
d_model,
num_heads,
d_head,
d_ff,
n_layers,
k,
dropouth=0.0,
dropouta=0.0,
):
super(SetTransformerDecoder, self).__init__()
self.n_layers = n_layers
self.k = k
self.d_model = d_model
self.pma = PMALayer(k, d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta)
self.pma = PMALayer(
k,
d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
layers = []
for _ in range(n_layers):
layers.append(
SetAttentionBlock(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta))
SetAttentionBlock(
d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
)
self.layers = nn.ModuleList(layers)
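A matching sketch for the decoder, which pools with k seed vectors (PMA) and returns a graph-wise readout of width k * d_model (sizes again arbitrary):
import dgl
import torch as th
from dgl.nn import SetTransformerDecoder
bg = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0, 1, 2], [1, 2, 3]))])
feat = th.rand(bg.num_nodes(), 16)
dec = SetTransformerDecoder(d_model=16, num_heads=4, d_head=4, d_ff=32, n_layers=1, k=2)
print(dec(bg, feat).shape)   # torch.Size([2, 32]) == (batch_size, k * d_model)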
......@@ -1236,12 +1361,12 @@ class WeightAndSum(nn.Module):
gcn_predictor.py>`__
to understand how to use the WeightAndSum layer to get the graph readout output.
"""
def __init__(self, in_feats):
super(WeightAndSum, self).__init__()
self.in_feats = in_feats
self.atom_weighting = nn.Sequential(
nn.Linear(in_feats, 1),
nn.Sigmoid()
nn.Linear(in_feats, 1), nn.Sigmoid()
)
def forward(self, g, feats):
......@@ -1261,8 +1386,8 @@ class WeightAndSum(nn.Module):
Representations for B molecules
"""
with g.local_scope():
g.ndata['h'] = feats
g.ndata['w'] = self.atom_weighting(g.ndata['h'])
h_g_sum = sum_nodes(g, 'h', 'w')
g.ndata["h"] = feats
g.ndata["w"] = self.atom_weighting(g.ndata["h"])
h_g_sum = sum_nodes(g, "h", "w")
return h_g_sum
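A usage sketch for WeightAndSum; each node feature is scaled by a learned sigmoid gate before the per-graph weighted sum (the feature size 8 is illustrative):
import dgl
import torch as th
from dgl.nn import WeightAndSum
bg = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
feats = th.rand(bg.num_nodes(), 8)
ws = WeightAndSum(in_feats=8)
print(ws(bg, feats).shape)   # torch.Size([2, 8]), one weighted sum per graph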
"""Heterograph NN modules"""
from functools import partial
import torch as th
import torch.nn as nn
from ...base import DGLError
__all__ = ['HeteroGraphConv', 'HeteroLinear', 'HeteroEmbedding']
__all__ = ["HeteroGraphConv", "HeteroLinear", "HeteroEmbedding"]
class HeteroGraphConv(nn.Module):
r"""A generic module for computing convolution on heterogeneous graphs.
......@@ -120,7 +123,8 @@ class HeteroGraphConv(nn.Module):
mods : dict[str, nn.Module]
Modules associated with every edge types.
"""
def __init__(self, mods, aggregate='sum'):
def __init__(self, mods, aggregate="sum"):
super(HeteroGraphConv, self).__init__()
self.mod_dict = mods
mods = {str(k): v for k, v in mods.items()}
......@@ -132,7 +136,9 @@ class HeteroGraphConv(nn.Module):
# Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items():
set_allow_zero_in_degree_fn = getattr(v, 'set_allow_zero_in_degree', None)
set_allow_zero_in_degree_fn = getattr(
v, "set_allow_zero_in_degree", None
)
if callable(set_allow_zero_in_degree_fn):
set_allow_zero_in_degree_fn(True)
if isinstance(aggregate, str):
......@@ -148,7 +154,7 @@ class HeteroGraphConv(nn.Module):
# etype is canonical
_, etype, _ = etype
return self.mod_dict[etype]
raise KeyError('Cannot find module with edge type %s' % etype)
raise KeyError("Cannot find module with edge type %s" % etype)
def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
"""Forward computation
......@@ -175,13 +181,15 @@ class HeteroGraphConv(nn.Module):
mod_args = {}
if mod_kwargs is None:
mod_kwargs = {}
outputs = {nty : [] for nty in g.dsttypes}
outputs = {nty: [] for nty in g.dsttypes}
if isinstance(inputs, tuple) or g.is_block:
if isinstance(inputs, tuple):
src_inputs, dst_inputs = inputs
else:
src_inputs = inputs
dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
dst_inputs = {
k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
}
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
......@@ -191,7 +199,8 @@ class HeteroGraphConv(nn.Module):
rel_graph,
(src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
**mod_kwargs.get(etype, {})
)
outputs[dtype].append(dstdata)
else:
for stype, etype, dtype in g.canonical_etypes:
......@@ -202,7 +211,8 @@ class HeteroGraphConv(nn.Module):
rel_graph,
(inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
**mod_kwargs.get(etype, {})
)
outputs[dtype].append(dstdata)
rsts = {}
for nty, alist in outputs.items():
......@@ -210,29 +220,36 @@ class HeteroGraphConv(nn.Module):
rsts[nty] = self.agg_fn(alist, nty)
return rsts
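A usage sketch of HeteroGraphConv on a toy heterograph; the relation names, GraphConv submodules and feature sizes are illustrative. Each relation runs its own module, and results landing on the same destination type are combined by the chosen aggregate (here "sum"):
import dgl
import torch as th
from dgl.nn import GraphConv, HeteroGraphConv
g = dgl.heterograph({
    ("user", "follows", "user"): ([0, 1], [1, 2]),
    ("user", "plays", "game"): ([0, 1, 2], [0, 0, 1]),
})
conv = HeteroGraphConv(
    {"follows": GraphConv(5, 8), "plays": GraphConv(5, 8)}, aggregate="sum"
)
inputs = {"user": th.rand(3, 5), "game": th.rand(2, 5)}
outputs = conv(g, inputs)
print(outputs["user"].shape, outputs["game"].shape)   # (3, 8) and (2, 8)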
def _max_reduce_func(inputs, dim):
return th.max(inputs, dim=dim)[0]
def _min_reduce_func(inputs, dim):
return th.min(inputs, dim=dim)[0]
def _sum_reduce_func(inputs, dim):
return th.sum(inputs, dim=dim)
def _mean_reduce_func(inputs, dim):
return th.mean(inputs, dim=dim)
def _stack_agg_func(inputs, dsttype): # pylint: disable=unused-argument
def _stack_agg_func(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
return th.stack(inputs, dim=1)
def _agg_func(inputs, dsttype, fn): # pylint: disable=unused-argument
def _agg_func(inputs, dsttype, fn): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
stacked = th.stack(inputs, dim=0)
return fn(stacked, dim=0)
def get_aggregate_fn(agg):
"""Internal function to get the aggregation function for node data
generated from different relations.
......@@ -249,24 +266,27 @@ def get_aggregate_fn(agg):
Aggregator function that takes a list of tensors to aggregate
and returns one aggregated tensor.
"""
if agg == 'sum':
if agg == "sum":
fn = _sum_reduce_func
elif agg == 'max':
elif agg == "max":
fn = _max_reduce_func
elif agg == 'min':
elif agg == "min":
fn = _min_reduce_func
elif agg == 'mean':
elif agg == "mean":
fn = _mean_reduce_func
elif agg == 'stack':
elif agg == "stack":
fn = None # will not be called
else:
raise DGLError('Invalid cross type aggregator. Must be one of '
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
if agg == 'stack':
raise DGLError(
"Invalid cross type aggregator. Must be one of "
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg
)
if agg == "stack":
return _stack_agg_func
else:
return partial(_agg_func, fn=fn)
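The reducers above only differ in how per-relation results for the same destination type are merged; a standalone sketch of "sum" versus "stack" on two hypothetical relation outputs:
import torch as th
out_rel_a = th.rand(3, 8)   # result of one relation on 3 destination nodes
out_rel_b = th.rand(3, 8)   # result of another relation on the same nodes
summed = th.sum(th.stack([out_rel_a, out_rel_b], dim=0), dim=0)   # "sum": (3, 8)
stacked = th.stack([out_rel_a, out_rel_b], dim=1)                 # "stack": (3, 2, 8)
print(summed.shape, stacked.shape)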
class HeteroLinear(nn.Module):
"""Apply linear transformations on heterogeneous inputs.
......@@ -294,6 +314,7 @@ class HeteroLinear(nn.Module):
>>> print(out_feats[('user', 'follows', 'user')].shape)
torch.Size([3, 3])
"""
def __init__(self, in_size, out_size, bias=True):
super(HeteroLinear, self).__init__()
......@@ -320,6 +341,7 @@ class HeteroLinear(nn.Module):
return out_feat
class HeteroEmbedding(nn.Module):
"""Create a heterogeneous embedding table.
......@@ -356,6 +378,7 @@ class HeteroEmbedding(nn.Module):
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([2, 4])
"""
def __init__(self, num_embeddings, embedding_dim):
super(HeteroEmbedding, self).__init__()
......@@ -374,7 +397,9 @@ class HeteroEmbedding(nn.Module):
dict[key, Tensor]
Heterogeneous embedding table
"""
return {self.raw_keys[typ]: emb.weight for typ, emb in self.embeds.items()}
return {
self.raw_keys[typ]: emb.weight for typ, emb in self.embeds.items()
}
def reset_parameters(self):
"""
......
"""Various commonly used linear modules"""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import math
import torch
import torch.nn as nn
from ...ops import segment_mm, gather_mm
from ...ops import gather_mm, segment_mm
__all__ = ["TypedLinear"]
__all__ = ['TypedLinear']
class TypedLinear(nn.Module):
r"""Linear transformation according to types.
......@@ -81,35 +83,43 @@ class TypedLinear(nn.Module):
>>> print(y.shape)
torch.Size([100, 64])
"""
def __init__(self, in_size, out_size, num_types,
regularizer=None, num_bases=None):
def __init__(
self, in_size, out_size, num_types, regularizer=None, num_bases=None
):
super().__init__()
self.in_size = in_size
self.out_size = out_size
self.num_types = num_types
if regularizer is None:
self.W = nn.Parameter(torch.Tensor(num_types, in_size, out_size))
elif regularizer == 'basis':
elif regularizer == "basis":
if num_bases is None:
raise ValueError('Missing "num_bases" for basis regularization.')
raise ValueError(
'Missing "num_bases" for basis regularization.'
)
self.W = nn.Parameter(torch.Tensor(num_bases, in_size, out_size))
self.coeff = nn.Parameter(torch.Tensor(num_types, num_bases))
self.num_bases = num_bases
elif regularizer == 'bdd':
elif regularizer == "bdd":
if num_bases is None:
raise ValueError('Missing "num_bases" for bdd regularization.')
if in_size % num_bases != 0 or out_size % num_bases != 0:
raise ValueError(
'Input and output sizes must be divisible by num_bases.'
"Input and output sizes must be divisible by num_bases."
)
self.submat_in = in_size // num_bases
self.submat_out = out_size // num_bases
self.W = nn.Parameter(torch.Tensor(
num_types, num_bases * self.submat_in * self.submat_out))
self.W = nn.Parameter(
torch.Tensor(
num_types, num_bases * self.submat_in * self.submat_out
)
)
self.num_bases = num_bases
else:
raise ValueError(
f'Supported regularizer options: "basis", "bdd", but got {regularizer}')
f'Supported regularizer options: "basis", "bdd", but got {regularizer}'
)
self.regularizer = regularizer
self.reset_parameters()
......@@ -118,28 +128,46 @@ class TypedLinear(nn.Module):
with torch.no_grad():
# Follow torch.nn.Linear 's initialization to use kaiming_uniform_ on in_size
if self.regularizer is None:
nn.init.uniform_(self.W, -1/math.sqrt(self.in_size), 1/math.sqrt(self.in_size))
elif self.regularizer == 'basis':
nn.init.uniform_(self.W, -1/math.sqrt(self.in_size), 1/math.sqrt(self.in_size))
nn.init.xavier_uniform_(self.coeff, gain=nn.init.calculate_gain('relu'))
elif self.regularizer == 'bdd':
nn.init.uniform_(self.W, -1/math.sqrt(self.submat_in), 1/math.sqrt(self.submat_in))
nn.init.uniform_(
self.W,
-1 / math.sqrt(self.in_size),
1 / math.sqrt(self.in_size),
)
elif self.regularizer == "basis":
nn.init.uniform_(
self.W,
-1 / math.sqrt(self.in_size),
1 / math.sqrt(self.in_size),
)
nn.init.xavier_uniform_(
self.coeff, gain=nn.init.calculate_gain("relu")
)
elif self.regularizer == "bdd":
nn.init.uniform_(
self.W,
-1 / math.sqrt(self.submat_in),
1 / math.sqrt(self.submat_in),
)
else:
raise ValueError(
f'Supported regularizer options: "basis", "bdd", but got {regularizer}')
f'Supported regularizer options: "basis", "bdd", but got {self.regularizer}'
)
def get_weight(self):
"""Get type-wise weight"""
if self.regularizer is None:
return self.W
elif self.regularizer == 'basis':
elif self.regularizer == "basis":
W = self.W.view(self.num_bases, self.in_size * self.out_size)
return (self.coeff @ W).view(self.num_types, self.in_size, self.out_size)
elif self.regularizer == 'bdd':
return (self.coeff @ W).view(
self.num_types, self.in_size, self.out_size
)
elif self.regularizer == "bdd":
return self.W
else:
raise ValueError(
f'Supported regularizer options: "basis", "bdd", but got {regularizer}')
f'Supported regularizer options: "basis", "bdd", but got {self.regularizer}'
)
def forward(self, x, x_type, sorted_by_type=False):
"""Forward computation.
......@@ -161,23 +189,35 @@ class TypedLinear(nn.Module):
The transformed output tensor. Shape: (N, D2)
"""
w = self.get_weight()
if self.regularizer == 'bdd':
w = w.index_select(0, x_type).view(-1, self.submat_in, self.submat_out)
if self.regularizer == "bdd":
w = w.index_select(0, x_type).view(
-1, self.submat_in, self.submat_out
)
x = x.view(-1, 1, self.submat_in)
return torch.bmm(x, w).view(-1, self.out_size)
elif sorted_by_type:
pos_l = torch.searchsorted(x_type, torch.arange(self.num_types, device=x.device))
pos_r = torch.cat([pos_l[1:], torch.tensor([len(x_type)], device=x.device)])
seglen = (pos_r - pos_l).cpu() # XXX(minjie): cause device synchronize
pos_l = torch.searchsorted(
x_type, torch.arange(self.num_types, device=x.device)
)
pos_r = torch.cat(
[pos_l[1:], torch.tensor([len(x_type)], device=x.device)]
)
seglen = (
pos_r - pos_l
).cpu() # XXX(minjie): cause device synchronize
return segment_mm(x, w, seglen_a=seglen)
else:
return gather_mm(x, w, idx_b=x_type)
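A usage sketch for TypedLinear; x_type selects one of num_types weight matrices per row. If the rows are already grouped by type, sorted_by_type=True switches from the gather_mm path to the segment_mm path (the sizes below are arbitrary, and availability of the gather_mm/segment_mm kernels depends on the installed build):
import torch
from dgl.nn import TypedLinear
x = torch.randn(100, 32)
x_type = torch.arange(100) % 5            # a type id in [0, 5) for every row
layer = TypedLinear(32, 64, 5)
print(layer(x, x_type).shape)             # torch.Size([100, 64]), gather_mm path
x_type_sorted, order = torch.sort(x_type)
y = layer(x[order], x_type_sorted, sorted_by_type=True)   # segment_mm path
print(y.shape)                            # torch.Size([100, 64])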
def __repr__(self):
if self.regularizer is None:
return (f'TypedLinear(in_size={self.in_size}, out_size={self.out_size}, '
f'num_types={self.num_types})')
return (
f"TypedLinear(in_size={self.in_size}, out_size={self.out_size}, "
f"num_types={self.num_types})"
)
else:
return (f'TypedLinear(in_size={self.in_size}, out_size={self.out_size}, '
f'num_types={self.num_types}, regularizer={self.regularizer}, '
f'num_bases={self.num_bases})')
return (
f"TypedLinear(in_size={self.in_size}, out_size={self.out_size}, "
f"num_types={self.num_types}, regularizer={self.regularizer}, "
f"num_bases={self.num_bases})"
)