Unverified Commit 76bb5404 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
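Note: every change below is mechanical reformatting produced by Black. Quotes are normalized to double quotes, long signatures and calls are split with one argument per line and a trailing comma, and spacing around operators is standardized; the code's behavior is unchanged. As a rough sketch of the transformation (the exact Black command and configuration behind this commit are not shown on this page, so the format_str call and the 80-character line length below are assumptions), the first changed signature can be reproduced like this:

# Illustrative sketch only, not the repository's actual tooling setup.
# Assumes a recent Black release and an 80-character line length.
import black

old_src = (
    "class DenseGraphConv(nn.Module):\n"
    "    def __init__(self,\n"
    "                 in_feats,\n"
    "                 out_feats,\n"
    "                 norm='both',\n"
    "                 bias=True,\n"
    "                 activation=None):\n"
    "        pass\n"
)

new_src = black.format_str(old_src, mode=black.Mode(line_length=80))
print(new_src)
# Under these assumptions this prints the Black-style version seen in the
# first hunk below:
# class DenseGraphConv(nn.Module):
#     def __init__(
#         self, in_feats, out_feats, norm="both", bias=True, activation=None
#     ):
#         pass

Black is idempotent, so re-running it on the formatted result leaves it unchanged, which is what makes a standalone auto-fix commit like this safe to apply and easy to review.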
@@ -63,12 +63,10 @@ class DenseGraphConv(nn.Module):
--------
`GraphConv <https://docs.dgl.ai/api/python/nn.pytorch.html#graphconv>`__
"""
def __init__(self,
in_feats,
out_feats,
norm='both',
bias=True,
activation=None):
def __init__(
self, in_feats, out_feats, norm="both", bias=True, activation=None
):
super(DenseGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
@@ -77,7 +75,7 @@ class DenseGraphConv(nn.Module):
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.register_buffer("bias", None)
self.reset_parameters()
self._activation = activation
@@ -114,7 +112,7 @@ class DenseGraphConv(nn.Module):
dst_degrees = adj.sum(dim=1).clamp(min=1)
feat_src = feat
if self._norm == 'both':
if self._norm == "both":
norm_src = th.pow(src_degrees, -0.5)
shp = norm_src.shape + (1,) * (feat.dim() - 1)
norm_src = th.reshape(norm_src, shp).to(feat.device)
@@ -129,10 +127,10 @@ class DenseGraphConv(nn.Module):
rst = adj @ feat_src
rst = th.matmul(rst, self.weight)
if self._norm != 'none':
if self._norm == 'both':
if self._norm != "none":
if self._norm == "both":
norm_dst = th.pow(dst_degrees, -0.5)
else: # right
else: # right
norm_dst = 1.0 / dst_degrees
shp = norm_dst.shape + (1,) * (feat.dim() - 1)
norm_dst = th.reshape(norm_dst, shp).to(feat.device)
"""Torch Module for DenseSAGEConv"""
# pylint: disable= no-member, arguments-differ, invalid-name
from torch import nn
from ....utils import check_eq_shape
@@ -56,13 +57,16 @@ class DenseSAGEConv(nn.Module):
--------
`SAGEConv <https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv>`__
"""
def __init__(self,
in_feats,
out_feats,
feat_drop=0.,
bias=True,
norm=None,
activation=None):
def __init__(
self,
in_feats,
out_feats,
feat_drop=0.0,
bias=True,
norm=None,
activation=None,
):
super(DenseSAGEConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
@@ -83,7 +87,7 @@ class DenseSAGEConv(nn.Module):
-----
The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
"""
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_uniform_(self.fc.weight, gain=gain)
def forward(self, adj, feat):
"""Torch Module for Directional Graph Networks Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from functools import partial
import torch
import torch.nn as nn
from .pnaconv import AGGREGATORS, SCALERS, PNAConv, PNAConvTower
def aggregate_dir_av(h, eig_s, eig_d, eig_idx):
"""directional average aggregation"""
h_mod = torch.mul(h, (
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) /
(torch.sum(torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True, dim=1) + 1e-30)).unsqueeze(-1))
h_mod = torch.mul(
h,
(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])
/ (
torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True,
dim=1,
)
+ 1e-30
)
).unsqueeze(-1),
)
return torch.sum(h_mod, dim=1)
def aggregate_dir_dx(h, eig_s, eig_d, h_in, eig_idx):
"""directional derivative aggregation"""
eig_w = ((
eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) /
(torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True, dim=1) + 1e-30
eig_w = (
(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])
/ (
torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True,
dim=1,
)
+ 1e-30
)
).unsqueeze(-1)
h_mod = torch.mul(h, eig_w)
return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in)
for k in range(1, 4):
AGGREGATORS[f'dir{k}-av'] = partial(aggregate_dir_av, eig_idx=k-1)
AGGREGATORS[f'dir{k}-dx'] = partial(aggregate_dir_dx, eig_idx=k-1)
AGGREGATORS[f"dir{k}-av"] = partial(aggregate_dir_av, eig_idx=k - 1)
AGGREGATORS[f"dir{k}-dx"] = partial(aggregate_dir_dx, eig_idx=k - 1)
class DGNConvTower(PNAConvTower):
"""A single DGN tower with modified reduce function"""
def message(self, edges):
"""message function for DGN layer"""
if self.edge_feat_size > 0:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['a']], dim=-1)
f = torch.cat(
[edges.src["h"], edges.dst["h"], edges.data["a"]], dim=-1
)
else:
f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
return {'msg': self.M(f), 'eig_s': edges.src['eig'], 'eig_d': edges.dst['eig']}
f = torch.cat([edges.src["h"], edges.dst["h"]], dim=-1)
return {
"msg": self.M(f),
"eig_s": edges.src["eig"],
"eig_d": edges.dst["eig"],
}
def reduce_func(self, nodes):
"""reduce function for DGN layer"""
h_in = nodes.data['h']
eig_s = nodes.mailbox['eig_s']
eig_d = nodes.mailbox['eig_d']
msg = nodes.mailbox['msg']
h_in = nodes.data["h"]
eig_s = nodes.mailbox["eig_s"]
eig_d = nodes.mailbox["eig_d"]
msg = nodes.mailbox["msg"]
degree = msg.size(1)
h = []
for agg in self.aggregators:
if agg.startswith('dir'):
if agg.endswith('av'):
if agg.startswith("dir"):
if agg.endswith("av"):
h.append(AGGREGATORS[agg](msg, eig_s, eig_d))
else:
h.append(AGGREGATORS[agg](msg, eig_s, eig_d, h_in))
else:
h.append(AGGREGATORS[agg](msg))
h = torch.cat(h, dim=1)
h = torch.cat([
SCALERS[scaler](h, D=degree, delta=self.delta) if scaler != 'identity' else h
for scaler in self.scalers
], dim=1)
return {'h_neigh': h}
h = torch.cat(
[
SCALERS[scaler](h, D=degree, delta=self.delta)
if scaler != "identity"
else h
for scaler in self.scalers
],
dim=1,
)
return {"h_neigh": h}
class DGNConv(PNAConv):
r"""Directional Graph Network Layer from `Directional Graph Networks
@@ -154,24 +187,49 @@ class DGNConv(PNAConv):
>>> conv = DGNConv(10, 10, ['dir1-av', 'dir1-dx', 'sum'], ['identity', 'amplification'], 2.5)
>>> ret = conv(g, feat, eig_vec=eig)
"""
def __init__(self, in_size, out_size, aggregators, scalers, delta,
dropout=0., num_towers=1, edge_feat_size=0, residual=True):
def __init__(
self,
in_size,
out_size,
aggregators,
scalers,
delta,
dropout=0.0,
num_towers=1,
edge_feat_size=0,
residual=True,
):
super(DGNConv, self).__init__(
in_size, out_size, aggregators, scalers, delta, dropout,
num_towers, edge_feat_size, residual
in_size,
out_size,
aggregators,
scalers,
delta,
dropout,
num_towers,
edge_feat_size,
residual,
)
self.towers = nn.ModuleList([
DGNConvTower(
self.tower_in_size, self.tower_out_size,
aggregators, scalers, delta,
dropout=dropout, edge_feat_size=edge_feat_size
) for _ in range(num_towers)
])
self.towers = nn.ModuleList(
[
DGNConvTower(
self.tower_in_size,
self.tower_out_size,
aggregators,
scalers,
delta,
dropout=dropout,
edge_feat_size=edge_feat_size,
)
for _ in range(num_towers)
]
)
self.use_eig_vec = False
for aggr in aggregators:
if aggr.startswith('dir'):
if aggr.startswith("dir"):
self.use_eig_vec = True
break
@@ -203,5 +261,5 @@
"""
with graph.local_scope():
if self.use_eig_vec:
graph.ndata['eig'] = eig_vec
graph.ndata["eig"] = eig_vec
return super().forward(graph, node_feat, edge_feat)
@@ -2,6 +2,7 @@
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn
from .... import function as fn
@@ -47,6 +48,7 @@ class EGNNConv(nn.Module):
>>> conv = EGNNConv(10, 10, 10, 2)
>>> h, x = conv(g, node_feat, coord_feat, edge_feat)
"""
def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
super(EGNNConv, self).__init__()
@@ -62,21 +64,21 @@ class EGNNConv(nn.Module):
nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
act_fn,
nn.Linear(hidden_size, hidden_size),
act_fn
act_fn,
)
# \phi_h
self.node_mlp = nn.Sequential(
nn.Linear(in_size + hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, out_size)
nn.Linear(hidden_size, out_size),
)
# \phi_x
self.coord_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, 1, bias=False)
nn.Linear(hidden_size, 1, bias=False),
)
def message(self, edges):
@@ -84,16 +86,23 @@ class EGNNConv(nn.Module):
# concat features for edge mlp
if self.edge_feat_size > 0:
f = torch.cat(
[edges.src['h'], edges.dst['h'], edges.data['radial'], edges.data['a']],
dim=-1
[
edges.src["h"],
edges.dst["h"],
edges.data["radial"],
edges.data["a"],
],
dim=-1,
)
else:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['radial']], dim=-1)
f = torch.cat(
[edges.src["h"], edges.dst["h"], edges.data["radial"]], dim=-1
)
msg_h = self.edge_mlp(f)
msg_x = self.coord_mlp(msg_h) * edges.data['x_diff']
msg_x = self.coord_mlp(msg_h) * edges.data["x_diff"]
return {'msg_x': msg_x, 'msg_h': msg_h}
return {"msg_x": msg_x, "msg_h": msg_h}
def forward(self, graph, node_feat, coord_feat, edge_feat=None):
r"""
@@ -126,27 +135,29 @@ class EGNNConv(nn.Module):
"""
with graph.local_scope():
# node feature
graph.ndata['h'] = node_feat
graph.ndata["h"] = node_feat
# coordinate feature
graph.ndata['x'] = coord_feat
graph.ndata["x"] = coord_feat
# edge feature
if self.edge_feat_size > 0:
assert edge_feat is not None, "Edge features must be provided."
graph.edata['a'] = edge_feat
graph.edata["a"] = edge_feat
# get coordinate diff & radial features
graph.apply_edges(fn.u_sub_v('x', 'x', 'x_diff'))
graph.edata['radial'] = graph.edata['x_diff'].square().sum(dim=1).unsqueeze(-1)
graph.apply_edges(fn.u_sub_v("x", "x", "x_diff"))
graph.edata["radial"] = (
graph.edata["x_diff"].square().sum(dim=1).unsqueeze(-1)
)
# normalize coordinate difference
graph.edata['x_diff'] = graph.edata['x_diff'] / (graph.edata['radial'].sqrt() + 1e-30)
graph.edata["x_diff"] = graph.edata["x_diff"] / (
graph.edata["radial"].sqrt() + 1e-30
)
graph.apply_edges(self.message)
graph.update_all(fn.copy_e('msg_x', 'm'), fn.mean('m', 'x_neigh'))
graph.update_all(fn.copy_e('msg_h', 'm'), fn.sum('m', 'h_neigh'))
graph.update_all(fn.copy_e("msg_x", "m"), fn.mean("m", "x_neigh"))
graph.update_all(fn.copy_e("msg_h", "m"), fn.sum("m", "h_neigh"))
h_neigh, x_neigh = graph.ndata['h_neigh'], graph.ndata['x_neigh']
h_neigh, x_neigh = graph.ndata["h_neigh"], graph.ndata["x_neigh"]
h = self.node_mlp(
torch.cat([node_feat, h_neigh], dim=-1)
)
h = self.node_mlp(torch.cat([node_feat, h_neigh], dim=-1))
x = coord_feat + x_neigh
return h, x
@@ -58,12 +58,7 @@ class GatedGraphConv(nn.Module):
0.1342, 0.0425]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feats,
out_feats,
n_steps,
n_etypes,
bias=True):
def __init__(self, in_feats, out_feats, n_steps, n_etypes, bias=True):
super(GatedGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
@@ -87,7 +82,7 @@ class GatedGraphConv(nn.Module):
The model parameters are initialized using Glorot uniform initialization
and the bias is initialized to be zero.
"""
gain = init.calculate_gain('relu')
gain = init.calculate_gain("relu")
self.gru.reset_parameters()
for linear in self.linears:
init.xavier_normal_(linear.weight, gain=gain)
@@ -134,36 +129,44 @@ class GatedGraphConv(nn.Module):
is the output feature size.
"""
with graph.local_scope():
assert graph.is_homogeneous, \
"not a homogeneous graph; convert it with to_homogeneous " \
assert graph.is_homogeneous, (
"not a homogeneous graph; convert it with to_homogeneous "
"and pass in the edge type as argument"
)
if self._n_etypes != 1:
assert etypes.min() >= 0 and etypes.max() < self._n_etypes, \
"edge type indices out of range [0, {})".format(
self._n_etypes)
assert (
etypes.min() >= 0 and etypes.max() < self._n_etypes
), "edge type indices out of range [0, {})".format(
self._n_etypes
)
zero_pad = feat.new_zeros(
(feat.shape[0], self._out_feats - feat.shape[1]))
(feat.shape[0], self._out_feats - feat.shape[1])
)
feat = th.cat([feat, zero_pad], -1)
for _ in range(self._n_steps):
if self._n_etypes == 1 and etypes is None:
# Fast path when graph has only one edge type
graph.ndata['h'] = self.linears[0](feat)
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'a'))
a = graph.ndata.pop('a') # (N, D)
graph.ndata["h"] = self.linears[0](feat)
graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "a"))
a = graph.ndata.pop("a") # (N, D)
else:
graph.ndata['h'] = feat
graph.ndata["h"] = feat
for i in range(self._n_etypes):
eids = th.nonzero(
etypes == i, as_tuple=False).view(-1).type(graph.idtype)
eids = (
th.nonzero(etypes == i, as_tuple=False)
.view(-1)
.type(graph.idtype)
)
if len(eids) > 0:
graph.apply_edges(
lambda edges: {
'W_e*h': self.linears[i](edges.src['h'])},
eids
"W_e*h": self.linears[i](edges.src["h"])
},
eids,
)
graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
a = graph.ndata.pop('a') # (N, D)
graph.update_all(fn.copy_e("W_e*h", "m"), fn.sum("m", "a"))
a = graph.ndata.pop("a") # (N, D)
feat = self.gru(a, feat)
return feat
@@ -4,10 +4,11 @@ import torch as th
from torch import nn
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ..utils import Identity
from ....utils import expand_as_pair
from ...functional import edge_softmax
from ..utils import Identity
# pylint: enable=W0235
class GATv2Conv(nn.Module):
@@ -134,18 +135,21 @@ class GATv2Conv(nn.Module):
[-1.1850, 0.1123],
[-0.2002, 0.1155]]], grad_fn=<GSpMMBackward>)
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
bias=True,
share_weights=False):
def __init__(
self,
in_feats,
out_feats,
num_heads,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
bias=True,
share_weights=False,
):
super(GATv2Conv, self).__init__()
self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
@@ -153,17 +157,21 @@ class GATv2Conv(nn.Module):
self._allow_zero_in_degree = allow_zero_in_degree
if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=bias)
self._in_src_feats, out_feats * num_heads, bias=bias
)
self.fc_dst = nn.Linear(
self._in_dst_feats, out_feats * num_heads, bias=bias)
self._in_dst_feats, out_feats * num_heads, bias=bias
)
else:
self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=bias)
self._in_src_feats, out_feats * num_heads, bias=bias
)
if share_weights:
self.fc_dst = self.fc_src
else:
self.fc_dst = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=bias)
self._in_src_feats, out_feats * num_heads, bias=bias
)
self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
@@ -171,11 +179,12 @@ class GATv2Conv(nn.Module):
if residual:
if self._in_dst_feats != out_feats * num_heads:
self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=bias)
self._in_dst_feats, num_heads * out_feats, bias=bias
)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
self.register_buffer("res_fc", None)
self.activation = activation
self.share_weights = share_weights
self.bias = bias
@@ -192,7 +201,7 @@ class GATv2Conv(nn.Module):
The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
The attention weights are using xavier initialization method.
"""
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
if self.bias:
nn.init.constant_(self.fc_src.bias, 0)
@@ -256,53 +265,70 @@ class GATv2Conv(nn.Module):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple):
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
feat_src = self.fc_src(h_src).view(
-1, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst).view(
-1, self._num_heads, self._out_feats
)
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = self.fc_src(h_src).view(
-1, self._num_heads, self._out_feats)
-1, self._num_heads, self._out_feats
)
if self.share_weights:
feat_dst = feat_src
else:
feat_dst = self.fc_dst(h_src).view(
-1, self._num_heads, self._out_feats)
-1, self._num_heads, self._out_feats
)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
graph.srcdata.update({'el': feat_src})# (num_src_edge, num_heads, out_dim)
graph.dstdata.update({'er': feat_dst})
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))# (num_src_edge, num_heads, out_dim)
e = (e * self.attn).sum(dim=-1).unsqueeze(dim=2)# (num_edge, num_heads, 1)
feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_dst = h_dst[: graph.number_of_dst_nodes()]
graph.srcdata.update(
{"el": feat_src}
) # (num_src_edge, num_heads, out_dim)
graph.dstdata.update({"er": feat_dst})
graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(
graph.edata.pop("e")
) # (num_src_edge, num_heads, out_dim)
e = (
(e * self.attn).sum(dim=-1).unsqueeze(dim=2)
) # (num_edge, num_heads, 1)
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) # (num_edge, num_heads)
graph.edata["a"] = self.attn_drop(
edge_softmax(graph, e)
) # (num_edge, num_heads)
# message passing
graph.update_all(fn.u_mul_e('el', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.dstdata['ft']
graph.update_all(fn.u_mul_e("el", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
# residual
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
resval = self.res_fc(h_dst).view(
h_dst.shape[0], -1, self._out_feats
)
rst = rst + resval
# activation
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
return rst, graph.edata["a"]
else:
return rst
@@ -104,15 +104,17 @@ class GCN2Conv(nn.Module):
"""
def __init__(self,
in_feats,
layer,
alpha=0.1,
lambda_=1,
project_initial_features=True,
allow_zero_in_degree=False,
bias=True,
activation=None):
def __init__(
self,
in_feats,
layer,
alpha=0.1,
lambda_=1,
project_initial_features=True,
allow_zero_in_degree=False,
bias=True,
activation=None,
):
super().__init__()
self._in_feats = in_feats
@@ -131,7 +133,8 @@ class GCN2Conv(nn.Module):
self.register_parameter("weight2", None)
else:
self.weight2 = nn.Parameter(
th.Tensor(self._in_feats, self._in_feats))
th.Tensor(self._in_feats, self._in_feats)
)
if self._bias:
self.bias = nn.Parameter(th.Tensor(self._in_feats))
@@ -233,7 +236,7 @@ class GCN2Conv(nn.Module):
norm = th.pow(degs, -0.5)
norm = norm.to(feat.device).unsqueeze(1)
else:
edge_weight = EdgeWeightNorm('both')(graph, edge_weight)
edge_weight = EdgeWeightNorm("both")(graph, edge_weight)
if edge_weight is None:
feat = feat * norm
@@ -255,14 +258,26 @@ class GCN2Conv(nn.Module):
if self._project_initial_features:
rst = feat.add_(feat_0)
rst = th.addmm(
feat, feat, self.weight1, beta=(1 - self.beta), alpha=self.beta
feat,
feat,
self.weight1,
beta=(1 - self.beta),
alpha=self.beta,
)
else:
rst = th.addmm(
feat, feat, self.weight1, beta=(1 - self.beta), alpha=self.beta
feat,
feat,
self.weight1,
beta=(1 - self.beta),
alpha=self.beta,
)
rst += th.addmm(
feat_0, feat_0, self.weight2, beta=(1 - self.beta), alpha=self.beta
feat_0,
feat_0,
self.weight2,
beta=(1 - self.beta),
alpha=self.beta,
)
if self._bias:
@@ -7,6 +7,7 @@ from torch import nn
from .... import function as fn
from ....utils import expand_as_pair
class GINEConv(nn.Module):
r"""Graph Isomorphism Network with Edge Features, introduced by
`Strategies for Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
@@ -45,21 +46,19 @@ class GINEConv(nn.Module):
>>> print(res.shape)
torch.Size([4, 20])
"""
def __init__(self,
apply_func=None,
init_eps=0,
learn_eps=False):
def __init__(self, apply_func=None, init_eps=0, learn_eps=False):
super(GINEConv, self).__init__()
self.apply_func = apply_func
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = nn.Parameter(th.FloatTensor([init_eps]))
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
self.register_buffer("eps", th.FloatTensor([init_eps]))
def message(self, edges):
r"""User-defined Message Function"""
return {'m': F.relu(edges.src['hn'] + edges.data['he'])}
return {"m": F.relu(edges.src["hn"] + edges.data["he"])}
def forward(self, graph, node_feat, edge_feat):
r"""Forward computation.
@@ -89,10 +88,10 @@ class GINEConv(nn.Module):
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(node_feat, graph)
graph.srcdata['hn'] = feat_src
graph.edata['he'] = edge_feat
graph.update_all(self.message, fn.sum('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
graph.srcdata["hn"] = feat_src
graph.edata["he"] = edge_feat
graph.update_all(self.message, fn.sum("m", "neigh"))
rst = (1 + self.eps) * feat_dst + graph.dstdata["neigh"]
if self.apply_func is not None:
rst = self.apply_func(rst)
return rst
"""Torch module for grouped reversible residual connections for GNNs"""
# pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
class InvertibleCheckpoint(torch.autograd.Function):
r"""Extension of torch.autograd"""
@staticmethod
def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):
ctx.fn = fn
@@ -40,19 +43,25 @@ class InvertibleCheckpoint(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("InvertibleCheckpoint is not compatible with .grad(), \
please use .backward() if possible")
raise RuntimeError(
"InvertibleCheckpoint is not compatible with .grad(), \
please use .backward() if possible"
)
# retrieve input and output tensor nodes
if len(ctx.outputs) == 0:
raise RuntimeError("Trying to perform backward on the InvertibleCheckpoint \
for more than once.")
raise RuntimeError(
"Trying to perform backward on the InvertibleCheckpoint \
for more than once."
)
inputs = ctx.inputs.pop()
outputs = ctx.outputs.pop()
# reconstruct input node features
with torch.no_grad():
# inputs[0] is DGLGraph and inputs[1] is input node features
inputs_inverted = ctx.fn_inverse(*((inputs[0], outputs)+inputs[2:]))
inputs_inverted = ctx.fn_inverse(
*((inputs[0], outputs) + inputs[2:])
)
# clear memory of outputs
outputs.storage().resize_(0)
@@ -72,11 +81,16 @@ class InvertibleCheckpoint(torch.autograd.Function):
detached_inputs = tuple(detached_inputs)
temp_output = ctx.fn(*detached_inputs)
filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False),
detached_inputs))
gradients = torch.autograd.grad(outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs)
filtered_detached_inputs = tuple(
filter(
lambda x: getattr(x, "requires_grad", False), detached_inputs
)
)
gradients = torch.autograd.grad(
outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs,
)
input_gradients = []
i = 0
@@ -87,7 +101,7 @@
else:
input_gradients.append(None)
gradients = tuple(input_gradients) + gradients[-len(ctx.weights):]
gradients = tuple(input_gradients) + gradients[-len(ctx.weights) :]
return (None, None, None) + gradients
@@ -157,6 +171,7 @@ class GroupRevRes(nn.Module):
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
"""
def __init__(self, gnn_module, groups=2):
super(GroupRevRes, self).__init__()
self.gnn_modules = nn.ModuleList()
@@ -173,7 +188,9 @@
if len(args) == 0:
args_chunks = [()] * self.groups
else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
chunked_args = list(
map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)
)
args_chunks = list(zip(*chunked_args))
y_in = sum(xs[1:])
@@ -192,13 +209,15 @@
if len(args) == 0:
args_chunks = [()] * self.groups
else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
chunked_args = list(
map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)
)
args_chunks = list(zip(*chunked_args))
xs = []
for i in range(self.groups-1, -1, -1):
for i in range(self.groups - 1, -1, -1):
if i != 0:
y_in = ys[i-1]
y_in = ys[i - 1]
else:
y_in = sum(xs)
@@ -232,6 +251,7 @@
self._forward,
self._inverse,
len(args),
*(args + tuple([p for p in self.parameters() if p.requires_grad])))
*(args + tuple([p for p in self.parameters() if p.requires_grad]))
)
return y
@@ -6,6 +6,7 @@ from torch import nn
from .... import function as fn
from ..linear import TypedLinear
class RelGraphConv(nn.Module):
r"""Relational graph convolution layer from `Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__
@@ -94,21 +95,26 @@ class RelGraphConv(nn.Module):
[-0.4323, -0.1440],
[-0.1309, -1.0000]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feat,
out_feat,
num_rels,
regularizer=None,
num_bases=None,
bias=True,
activation=None,
self_loop=True,
dropout=0.0,
layer_norm=False):
def __init__(
self,
in_feat,
out_feat,
num_rels,
regularizer=None,
num_bases=None,
bias=True,
activation=None,
self_loop=True,
dropout=0.0,
layer_norm=False,
):
super().__init__()
if regularizer is not None and num_bases is None:
num_bases = num_rels
self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases)
self.linear_r = TypedLinear(
in_feat, out_feat, num_rels, regularizer, num_bases
)
self.bias = bias
self.activation = activation
self.self_loop = self_loop
@@ -123,21 +129,25 @@
# the module only about graph convolution.
# layer norm
if self.layer_norm:
self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)
self.layer_norm_weight = nn.LayerNorm(
out_feat, elementwise_affine=True
)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.loop_weight, gain=nn.init.calculate_gain("relu")
)
self.dropout = nn.Dropout(dropout)
def message(self, edges):
"""Message function."""
m = self.linear_r(edges.src['h'], edges.data['etype'], self.presorted)
if 'norm' in edges.data:
m = m * edges.data['norm']
return {'m' : m}
m = self.linear_r(edges.src["h"], edges.data["etype"], self.presorted)
if "norm" in edges.data:
m = m * edges.data["norm"]
return {"m": m}
def forward(self, g, feat, etypes, norm=None, *, presorted=False):
"""Forward computation.
@@ -165,20 +175,20 @@
"""
self.presorted = presorted
with g.local_scope():
g.srcdata['h'] = feat
g.srcdata["h"] = feat
if norm is not None:
g.edata['norm'] = norm
g.edata['etype'] = etypes
g.edata["norm"] = norm
g.edata["etype"] = etypes
# message passing
g.update_all(self.message, fn.sum('m', 'h'))
g.update_all(self.message, fn.sum("m", "h"))
# apply bias and activation
h = g.dstdata['h']
h = g.dstdata["h"]
if self.layer_norm:
h = self.layer_norm_weight(h)
if self.bias:
h = h + self.h_bias
if self.self_loop:
h = h + feat[:g.num_dst_nodes()] @ self.loop_weight
h = h + feat[: g.num_dst_nodes()] @ self.loop_weight
if self.activation:
h = self.activation(h)
h = self.dropout(h)
@@ -3,4 +3,4 @@
from .gnnexplainer import GNNExplainer
__all__ = ['GNNExplainer']
__all__ = ["GNNExplainer"]