Unverified commit 76bb5404 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
@@ -63,12 +63,10 @@ class DenseGraphConv(nn.Module):
--------
`GraphConv <https://docs.dgl.ai/api/python/nn.pytorch.html#graphconv>`__
"""
def __init__(self,
in_feats,
out_feats,
norm='both',
bias=True,
activation=None):
def __init__(
self, in_feats, out_feats, norm="both", bias=True, activation=None
):
super(DenseGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
@@ -77,7 +75,7 @@ class DenseGraphConv(nn.Module):
if bias:
self.bias = nn.Parameter(th.Tensor(out_feats))
else:
self.register_buffer('bias', None)
self.register_buffer("bias", None)
self.reset_parameters()
self._activation = activation
@@ -114,7 +112,7 @@ class DenseGraphConv(nn.Module):
dst_degrees = adj.sum(dim=1).clamp(min=1)
feat_src = feat
if self._norm == 'both':
if self._norm == "both":
norm_src = th.pow(src_degrees, -0.5)
shp = norm_src.shape + (1,) * (feat.dim() - 1)
norm_src = th.reshape(norm_src, shp).to(feat.device)
@@ -129,8 +127,8 @@ class DenseGraphConv(nn.Module):
rst = adj @ feat_src
rst = th.matmul(rst, self.weight)
if self._norm != 'none':
if self._norm == 'both':
if self._norm != "none":
if self._norm == "both":
norm_dst = th.pow(dst_degrees, -0.5)
else: # right
norm_dst = 1.0 / dst_degrees
......
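For orientation, DenseGraphConv is the dense-adjacency counterpart of GraphConv: it takes an adjacency matrix and a feature tensor instead of a DGLGraph, and the norm argument seen above selects the degree normalization. A minimal usage sketch, not part of this commit (the toy adjacency and the from dgl.nn import path are assumptions):

import torch as th
from dgl.nn import DenseGraphConv  # assumed top-level re-export

# 3-node path graph as a dense adjacency matrix
adj = th.tensor([[0.0, 1.0, 0.0],
                 [1.0, 0.0, 1.0],
                 [0.0, 1.0, 0.0]])
feat = th.randn(3, 5)

conv = DenseGraphConv(in_feats=5, out_feats=2, norm="both")
out = conv(adj, feat)  # shape (3, 2)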
"""Torch Module for DenseSAGEConv"""
# pylint: disable= no-member, arguments-differ, invalid-name
from torch import nn
from ....utils import check_eq_shape
@@ -56,13 +57,16 @@ class DenseSAGEConv(nn.Module):
--------
`SAGEConv <https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv>`__
"""
def __init__(self,
def __init__(
self,
in_feats,
out_feats,
feat_drop=0.,
feat_drop=0.0,
bias=True,
norm=None,
activation=None):
activation=None,
):
super(DenseSAGEConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
@@ -83,7 +87,7 @@ class DenseSAGEConv(nn.Module):
-----
The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
"""
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_uniform_(self.fc.weight, gain=gain)
def forward(self, adj, feat):
......
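DenseSAGEConv is the matching dense-adjacency version of SAGEConv; the hunk above only reflows its constructor and reset_parameters. A short instantiation sketch under the same assumptions (toy data, assumed import path):

import torch as th
from dgl.nn import DenseSAGEConv  # assumed re-export

adj = (th.rand(4, 4) > 0.5).float()  # random dense adjacency
feat = th.randn(4, 5)

conv = DenseSAGEConv(in_feats=5, out_feats=3, feat_drop=0.0)
out = conv(adj, feat)  # shape (4, 3)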
"""Torch Module for Directional Graph Networks Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from functools import partial
import torch
import torch.nn as nn
from .pnaconv import AGGREGATORS, SCALERS, PNAConv, PNAConvTower
def aggregate_dir_av(h, eig_s, eig_d, eig_idx):
"""directional average aggregation"""
h_mod = torch.mul(h, (
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) /
(torch.sum(torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True, dim=1) + 1e-30)).unsqueeze(-1))
h_mod = torch.mul(
h,
(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])
/ (
torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True,
dim=1,
)
+ 1e-30
)
).unsqueeze(-1),
)
return torch.sum(h_mod, dim=1)
def aggregate_dir_dx(h, eig_s, eig_d, h_in, eig_idx):
"""directional derivative aggregation"""
eig_w = ((
eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) /
(torch.sum(
eig_w = (
(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])
/ (
torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True, dim=1) + 1e-30
keepdim=True,
dim=1,
)
+ 1e-30
)
).unsqueeze(-1)
h_mod = torch.mul(h, eig_w)
return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in)
for k in range(1, 4):
AGGREGATORS[f'dir{k}-av'] = partial(aggregate_dir_av, eig_idx=k-1)
AGGREGATORS[f'dir{k}-dx'] = partial(aggregate_dir_dx, eig_idx=k-1)
AGGREGATORS[f"dir{k}-av"] = partial(aggregate_dir_av, eig_idx=k - 1)
AGGREGATORS[f"dir{k}-dx"] = partial(aggregate_dir_dx, eig_idx=k - 1)
class DGNConvTower(PNAConvTower):
"""A single DGN tower with modified reduce function"""
def message(self, edges):
"""message function for DGN layer"""
if self.edge_feat_size > 0:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['a']], dim=-1)
f = torch.cat(
[edges.src["h"], edges.dst["h"], edges.data["a"]], dim=-1
)
else:
f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1)
return {'msg': self.M(f), 'eig_s': edges.src['eig'], 'eig_d': edges.dst['eig']}
f = torch.cat([edges.src["h"], edges.dst["h"]], dim=-1)
return {
"msg": self.M(f),
"eig_s": edges.src["eig"],
"eig_d": edges.dst["eig"],
}
def reduce_func(self, nodes):
"""reduce function for DGN layer"""
h_in = nodes.data['h']
eig_s = nodes.mailbox['eig_s']
eig_d = nodes.mailbox['eig_d']
msg = nodes.mailbox['msg']
h_in = nodes.data["h"]
eig_s = nodes.mailbox["eig_s"]
eig_d = nodes.mailbox["eig_d"]
msg = nodes.mailbox["msg"]
degree = msg.size(1)
h = []
for agg in self.aggregators:
if agg.startswith('dir'):
if agg.endswith('av'):
if agg.startswith("dir"):
if agg.endswith("av"):
h.append(AGGREGATORS[agg](msg, eig_s, eig_d))
else:
h.append(AGGREGATORS[agg](msg, eig_s, eig_d, h_in))
else:
h.append(AGGREGATORS[agg](msg))
h = torch.cat(h, dim=1)
h = torch.cat([
SCALERS[scaler](h, D=degree, delta=self.delta) if scaler != 'identity' else h
h = torch.cat(
[
SCALERS[scaler](h, D=degree, delta=self.delta)
if scaler != "identity"
else h
for scaler in self.scalers
], dim=1)
return {'h_neigh': h}
],
dim=1,
)
return {"h_neigh": h}
class DGNConv(PNAConv):
r"""Directional Graph Network Layer from `Directional Graph Networks
@@ -154,24 +187,49 @@ class DGNConv(PNAConv):
>>> conv = DGNConv(10, 10, ['dir1-av', 'dir1-dx', 'sum'], ['identity', 'amplification'], 2.5)
>>> ret = conv(g, feat, eig_vec=eig)
"""
def __init__(self, in_size, out_size, aggregators, scalers, delta,
dropout=0., num_towers=1, edge_feat_size=0, residual=True):
def __init__(
self,
in_size,
out_size,
aggregators,
scalers,
delta,
dropout=0.0,
num_towers=1,
edge_feat_size=0,
residual=True,
):
super(DGNConv, self).__init__(
in_size, out_size, aggregators, scalers, delta, dropout,
num_towers, edge_feat_size, residual
in_size,
out_size,
aggregators,
scalers,
delta,
dropout,
num_towers,
edge_feat_size,
residual,
)
self.towers = nn.ModuleList([
self.towers = nn.ModuleList(
[
DGNConvTower(
self.tower_in_size, self.tower_out_size,
aggregators, scalers, delta,
dropout=dropout, edge_feat_size=edge_feat_size
) for _ in range(num_towers)
])
self.tower_in_size,
self.tower_out_size,
aggregators,
scalers,
delta,
dropout=dropout,
edge_feat_size=edge_feat_size,
)
for _ in range(num_towers)
]
)
self.use_eig_vec = False
for aggr in aggregators:
if aggr.startswith('dir'):
if aggr.startswith("dir"):
self.use_eig_vec = True
break
@@ -203,5 +261,5 @@
"""
with graph.local_scope():
if self.use_eig_vec:
graph.ndata['eig'] = eig_vec
graph.ndata["eig"] = eig_vec
return super().forward(graph, node_feat, edge_feat)
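The DGN changes above re-wrap the directional aggregators ("dir{k}-av", "dir{k}-dx") and the tower construction without changing behavior. As a reminder of how the layer is driven, a sketch along the lines of the class docstring; the random eig tensor is a stand-in for Laplacian eigenvectors, and the import path is assumed:

import dgl
import torch
from dgl.nn import DGNConv  # assumed re-export

g = dgl.add_self_loop(dgl.rand_graph(6, 12))
feat = torch.randn(6, 10)
eig = torch.randn(6, 3)  # placeholder for Laplacian eigenvectors

conv = DGNConv(10, 10,
               aggregators=["dir1-av", "dir1-dx", "sum"],
               scalers=["identity", "amplification"],
               delta=2.5)
out = conv(g, feat, eig_vec=eig)  # shape (6, 10)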
@@ -2,6 +2,7 @@
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn
from .... import function as fn
@@ -47,6 +48,7 @@ class EGNNConv(nn.Module):
>>> conv = EGNNConv(10, 10, 10, 2)
>>> h, x = conv(g, node_feat, coord_feat, edge_feat)
"""
def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
super(EGNNConv, self).__init__()
@@ -62,21 +64,21 @@ class EGNNConv(nn.Module):
nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
act_fn,
nn.Linear(hidden_size, hidden_size),
act_fn
act_fn,
)
# \phi_h
self.node_mlp = nn.Sequential(
nn.Linear(in_size + hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, out_size)
nn.Linear(hidden_size, out_size),
)
# \phi_x
self.coord_mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
act_fn,
nn.Linear(hidden_size, 1, bias=False)
nn.Linear(hidden_size, 1, bias=False),
)
def message(self, edges):
@@ -84,16 +86,23 @@ class EGNNConv(nn.Module):
# concat features for edge mlp
if self.edge_feat_size > 0:
f = torch.cat(
[edges.src['h'], edges.dst['h'], edges.data['radial'], edges.data['a']],
dim=-1
[
edges.src["h"],
edges.dst["h"],
edges.data["radial"],
edges.data["a"],
],
dim=-1,
)
else:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['radial']], dim=-1)
f = torch.cat(
[edges.src["h"], edges.dst["h"], edges.data["radial"]], dim=-1
)
msg_h = self.edge_mlp(f)
msg_x = self.coord_mlp(msg_h) * edges.data['x_diff']
msg_x = self.coord_mlp(msg_h) * edges.data["x_diff"]
return {'msg_x': msg_x, 'msg_h': msg_h}
return {"msg_x": msg_x, "msg_h": msg_h}
def forward(self, graph, node_feat, coord_feat, edge_feat=None):
r"""
@@ -126,27 +135,29 @@ class EGNNConv(nn.Module):
"""
with graph.local_scope():
# node feature
graph.ndata['h'] = node_feat
graph.ndata["h"] = node_feat
# coordinate feature
graph.ndata['x'] = coord_feat
graph.ndata["x"] = coord_feat
# edge feature
if self.edge_feat_size > 0:
assert edge_feat is not None, "Edge features must be provided."
graph.edata['a'] = edge_feat
graph.edata["a"] = edge_feat
# get coordinate diff & radial features
graph.apply_edges(fn.u_sub_v('x', 'x', 'x_diff'))
graph.edata['radial'] = graph.edata['x_diff'].square().sum(dim=1).unsqueeze(-1)
graph.apply_edges(fn.u_sub_v("x", "x", "x_diff"))
graph.edata["radial"] = (
graph.edata["x_diff"].square().sum(dim=1).unsqueeze(-1)
)
# normalize coordinate difference
graph.edata['x_diff'] = graph.edata['x_diff'] / (graph.edata['radial'].sqrt() + 1e-30)
graph.edata["x_diff"] = graph.edata["x_diff"] / (
graph.edata["radial"].sqrt() + 1e-30
)
graph.apply_edges(self.message)
graph.update_all(fn.copy_e('msg_x', 'm'), fn.mean('m', 'x_neigh'))
graph.update_all(fn.copy_e('msg_h', 'm'), fn.sum('m', 'h_neigh'))
graph.update_all(fn.copy_e("msg_x", "m"), fn.mean("m", "x_neigh"))
graph.update_all(fn.copy_e("msg_h", "m"), fn.sum("m", "h_neigh"))
h_neigh, x_neigh = graph.ndata['h_neigh'], graph.ndata['x_neigh']
h_neigh, x_neigh = graph.ndata["h_neigh"], graph.ndata["x_neigh"]
h = self.node_mlp(
torch.cat([node_feat, h_neigh], dim=-1)
)
h = self.node_mlp(torch.cat([node_feat, h_neigh], dim=-1))
x = coord_feat + x_neigh
return h, x
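EGNNConv returns updated node features together with updated coordinates, which is why the forward above ends with return h, x. A small sketch of the call, mirroring the signature in the docstring (toy sizes and import path assumed):

import dgl
import torch
from dgl.nn import EGNNConv  # assumed re-export

g = dgl.rand_graph(5, 20)
node_feat = torch.randn(5, 10)
coord_feat = torch.randn(5, 3)  # e.g. 3-D positions
edge_feat = torch.randn(20, 2)

conv = EGNNConv(in_size=10, hidden_size=10, out_size=10, edge_feat_size=2)
h, x = conv(g, node_feat, coord_feat, edge_feat)
# h: (5, 10) updated features, x: (5, 3) updated coordinates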
@@ -58,12 +58,7 @@ class GatedGraphConv(nn.Module):
0.1342, 0.0425]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feats,
out_feats,
n_steps,
n_etypes,
bias=True):
def __init__(self, in_feats, out_feats, n_steps, n_etypes, bias=True):
super(GatedGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
@@ -87,7 +82,7 @@ class GatedGraphConv(nn.Module):
The model parameters are initialized using Glorot uniform initialization
and the bias is initialized to be zero.
"""
gain = init.calculate_gain('relu')
gain = init.calculate_gain("relu")
self.gru.reset_parameters()
for linear in self.linears:
init.xavier_normal_(linear.weight, gain=gain)
@@ -134,36 +129,44 @@ class GatedGraphConv(nn.Module):
is the output feature size.
"""
with graph.local_scope():
assert graph.is_homogeneous, \
"not a homogeneous graph; convert it with to_homogeneous " \
assert graph.is_homogeneous, (
"not a homogeneous graph; convert it with to_homogeneous "
"and pass in the edge type as argument"
)
if self._n_etypes != 1:
assert etypes.min() >= 0 and etypes.max() < self._n_etypes, \
"edge type indices out of range [0, {})".format(
self._n_etypes)
assert (
etypes.min() >= 0 and etypes.max() < self._n_etypes
), "edge type indices out of range [0, {})".format(
self._n_etypes
)
zero_pad = feat.new_zeros(
(feat.shape[0], self._out_feats - feat.shape[1]))
(feat.shape[0], self._out_feats - feat.shape[1])
)
feat = th.cat([feat, zero_pad], -1)
for _ in range(self._n_steps):
if self._n_etypes == 1 and etypes is None:
# Fast path when graph has only one edge type
graph.ndata['h'] = self.linears[0](feat)
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'a'))
a = graph.ndata.pop('a') # (N, D)
graph.ndata["h"] = self.linears[0](feat)
graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "a"))
a = graph.ndata.pop("a") # (N, D)
else:
graph.ndata['h'] = feat
graph.ndata["h"] = feat
for i in range(self._n_etypes):
eids = th.nonzero(
etypes == i, as_tuple=False).view(-1).type(graph.idtype)
eids = (
th.nonzero(etypes == i, as_tuple=False)
.view(-1)
.type(graph.idtype)
)
if len(eids) > 0:
graph.apply_edges(
lambda edges: {
'W_e*h': self.linears[i](edges.src['h'])},
eids
"W_e*h": self.linears[i](edges.src["h"])
},
eids,
)
graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
a = graph.ndata.pop('a') # (N, D)
graph.update_all(fn.copy_e("W_e*h", "m"), fn.sum("m", "a"))
a = graph.ndata.pop("a") # (N, D)
feat = self.gru(a, feat)
return feat
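GatedGraphConv zero-pads the input features up to out_feats and then runs n_steps GRU updates, using the single-edge-type fast path above when possible. A usage sketch (import path assumed; out_feats must be at least in_feats because of the padding):

import dgl
import torch as th
from dgl.nn import GatedGraphConv  # assumed re-export

g = dgl.rand_graph(4, 8)
feat = th.randn(4, 5)
etypes = th.randint(0, 2, (8,))  # one edge-type id per edge

conv = GatedGraphConv(in_feats=5, out_feats=10, n_steps=2, n_etypes=2)
out = conv(g, feat, etypes)  # shape (4, 10)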
@@ -4,10 +4,11 @@ import torch as th
from torch import nn
from .... import function as fn
from ...functional import edge_softmax
from ....base import DGLError
from ..utils import Identity
from ....utils import expand_as_pair
from ...functional import edge_softmax
from ..utils import Identity
# pylint: enable=W0235
class GATv2Conv(nn.Module):
@@ -134,18 +135,21 @@ class GATv2Conv(nn.Module):
[-1.1850, 0.1123],
[-0.2002, 0.1155]]], grad_fn=<GSpMMBackward>)
"""
def __init__(self,
def __init__(
self,
in_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
bias=True,
share_weights=False):
share_weights=False,
):
super(GATv2Conv, self).__init__()
self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
@@ -153,17 +157,21 @@ class GATv2Conv(nn.Module):
self._allow_zero_in_degree = allow_zero_in_degree
if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=bias)
self._in_src_feats, out_feats * num_heads, bias=bias
)
self.fc_dst = nn.Linear(
self._in_dst_feats, out_feats * num_heads, bias=bias)
self._in_dst_feats, out_feats * num_heads, bias=bias
)
else:
self.fc_src = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=bias)
self._in_src_feats, out_feats * num_heads, bias=bias
)
if share_weights:
self.fc_dst = self.fc_src
else:
self.fc_dst = nn.Linear(
self._in_src_feats, out_feats * num_heads, bias=bias)
self._in_src_feats, out_feats * num_heads, bias=bias
)
self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
@@ -171,11 +179,12 @@ class GATv2Conv(nn.Module):
if residual:
if self._in_dst_feats != out_feats * num_heads:
self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=bias)
self._in_dst_feats, num_heads * out_feats, bias=bias
)
else:
self.res_fc = Identity()
else:
self.register_buffer('res_fc', None)
self.register_buffer("res_fc", None)
self.activation = activation
self.share_weights = share_weights
self.bias = bias
@@ -192,7 +201,7 @@ class GATv2Conv(nn.Module):
The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
The attention weights are initialized using the Xavier initialization method.
"""
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
if self.bias:
nn.init.constant_(self.fc_src.bias, 0)
@@ -256,53 +265,70 @@ class GATv2Conv(nn.Module):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple):
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
feat_src = self.fc_src(h_src).view(
-1, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst).view(
-1, self._num_heads, self._out_feats
)
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = self.fc_src(h_src).view(
-1, self._num_heads, self._out_feats)
-1, self._num_heads, self._out_feats
)
if self.share_weights:
feat_dst = feat_src
else:
feat_dst = self.fc_dst(h_src).view(
-1, self._num_heads, self._out_feats)
-1, self._num_heads, self._out_feats
)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
graph.srcdata.update({'el': feat_src})# (num_src_edge, num_heads, out_dim)
graph.dstdata.update({'er': feat_dst})
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))# (num_src_edge, num_heads, out_dim)
e = (e * self.attn).sum(dim=-1).unsqueeze(dim=2)# (num_edge, num_heads, 1)
feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_dst = h_dst[: graph.number_of_dst_nodes()]
graph.srcdata.update(
{"el": feat_src}
) # (num_src_edge, num_heads, out_dim)
graph.dstdata.update({"er": feat_dst})
graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(
graph.edata.pop("e")
) # (num_src_edge, num_heads, out_dim)
e = (
(e * self.attn).sum(dim=-1).unsqueeze(dim=2)
) # (num_edge, num_heads, 1)
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) # (num_edge, num_heads)
graph.edata["a"] = self.attn_drop(
edge_softmax(graph, e)
) # (num_edge, num_heads)
# message passing
graph.update_all(fn.u_mul_e('el', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.dstdata['ft']
graph.update_all(fn.u_mul_e("el", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
# residual
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
resval = self.res_fc(h_dst).view(
h_dst.shape[0], -1, self._out_feats
)
rst = rst + resval
# activation
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
return rst, graph.edata["a"]
else:
return rst
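Most of the GATv2Conv hunk is quote and wrapping churn around the attention computation. For reference, a minimal call (import path assumed; the self-loop avoids the 0-in-degree error raised above):

import dgl
import torch as th
from dgl.nn import GATv2Conv  # assumed re-export

g = dgl.add_self_loop(dgl.rand_graph(6, 15))
feat = th.randn(6, 10)

conv = GATv2Conv(in_feats=10, out_feats=4, num_heads=3)
out = conv(g, feat)  # shape (6, 3, 4)
out, attn = conv(g, feat, get_attention=True)  # attn: (num_edges, 3, 1)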
@@ -104,7 +104,8 @@ class GCN2Conv(nn.Module):
"""
def __init__(self,
def __init__(
self,
in_feats,
layer,
alpha=0.1,
@@ -112,7 +113,8 @@ class GCN2Conv(nn.Module):
project_initial_features=True,
allow_zero_in_degree=False,
bias=True,
activation=None):
activation=None,
):
super().__init__()
self._in_feats = in_feats
@@ -131,7 +133,8 @@ class GCN2Conv(nn.Module):
self.register_parameter("weight2", None)
else:
self.weight2 = nn.Parameter(
th.Tensor(self._in_feats, self._in_feats))
th.Tensor(self._in_feats, self._in_feats)
)
if self._bias:
self.bias = nn.Parameter(th.Tensor(self._in_feats))
@@ -233,7 +236,7 @@ class GCN2Conv(nn.Module):
norm = th.pow(degs, -0.5)
norm = norm.to(feat.device).unsqueeze(1)
else:
edge_weight = EdgeWeightNorm('both')(graph, edge_weight)
edge_weight = EdgeWeightNorm("both")(graph, edge_weight)
if edge_weight is None:
feat = feat * norm
@@ -255,14 +258,26 @@ class GCN2Conv(nn.Module):
if self._project_initial_features:
rst = feat.add_(feat_0)
rst = th.addmm(
feat, feat, self.weight1, beta=(1 - self.beta), alpha=self.beta
feat,
feat,
self.weight1,
beta=(1 - self.beta),
alpha=self.beta,
)
else:
rst = th.addmm(
feat, feat, self.weight1, beta=(1 - self.beta), alpha=self.beta
feat,
feat,
self.weight1,
beta=(1 - self.beta),
alpha=self.beta,
)
rst += th.addmm(
feat_0, feat_0, self.weight2, beta=(1 - self.beta), alpha=self.beta
feat_0,
feat_0,
self.weight2,
beta=(1 - self.beta),
alpha=self.beta,
)
if self._bias:
......
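GCN2Conv keeps the feature width fixed and mixes the current representation with the initial features through the two addmm terms above. A sketch of stacking two layers, with feat_0 being the initial feature matrix that every layer receives (import path and exact call pattern assumed from the class documentation, not shown in this hunk):

import dgl
import torch as th
from dgl.nn import GCN2Conv  # assumed re-export

g = dgl.add_self_loop(dgl.rand_graph(6, 12))
feat_0 = th.randn(6, 3)

conv1 = GCN2Conv(3, layer=1, alpha=0.5, project_initial_features=True)
conv2 = GCN2Conv(3, layer=2, alpha=0.5, project_initial_features=True)

h = conv1(g, feat_0, feat_0)
h = conv2(g, h, feat_0)  # shape stays (6, 3)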
@@ -7,6 +7,7 @@ from torch import nn
from .... import function as fn
from ....utils import expand_as_pair
class GINEConv(nn.Module):
r"""Graph Isomorphism Network with Edge Features, introduced by
`Strategies for Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
@@ -45,21 +46,19 @@ class GINEConv(nn.Module):
>>> print(res.shape)
torch.Size([4, 20])
"""
def __init__(self,
apply_func=None,
init_eps=0,
learn_eps=False):
def __init__(self, apply_func=None, init_eps=0, learn_eps=False):
super(GINEConv, self).__init__()
self.apply_func = apply_func
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = nn.Parameter(th.FloatTensor([init_eps]))
else:
self.register_buffer('eps', th.FloatTensor([init_eps]))
self.register_buffer("eps", th.FloatTensor([init_eps]))
def message(self, edges):
r"""User-defined Message Function"""
return {'m': F.relu(edges.src['hn'] + edges.data['he'])}
return {"m": F.relu(edges.src["hn"] + edges.data["he"])}
def forward(self, graph, node_feat, edge_feat):
r"""Forward computation.
@@ -89,10 +88,10 @@ class GINEConv(nn.Module):
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(node_feat, graph)
graph.srcdata['hn'] = feat_src
graph.edata['he'] = edge_feat
graph.update_all(self.message, fn.sum('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
graph.srcdata["hn"] = feat_src
graph.edata["he"] = edge_feat
graph.update_all(self.message, fn.sum("m", "neigh"))
rst = (1 + self.eps) * feat_dst + graph.dstdata["neigh"]
if self.apply_func is not None:
rst = self.apply_func(rst)
return rst
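GINEConv adds edge features to source node features inside the message function above, so the two must have the same width; apply_func then maps the aggregated result to the output size. A sketch (import path assumed):

import dgl
import torch as th
import torch.nn as nn
from dgl.nn import GINEConv  # assumed re-export

g = dgl.rand_graph(4, 9)
node_feat = th.randn(4, 10)
edge_feat = th.randn(9, 10)  # must match the node feature size

conv = GINEConv(apply_func=nn.Linear(10, 20))
out = conv(g, node_feat, edge_feat)  # shape (4, 20)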
"""Torch module for grouped reversible residual connections for GNNs"""
# pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
class InvertibleCheckpoint(torch.autograd.Function):
r"""Extension of torch.autograd"""
@staticmethod
def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):
ctx.fn = fn
@@ -40,19 +43,25 @@ class InvertibleCheckpoint(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("InvertibleCheckpoint is not compatible with .grad(), \
please use .backward() if possible")
raise RuntimeError(
"InvertibleCheckpoint is not compatible with .grad(), \
please use .backward() if possible"
)
# retrieve input and output tensor nodes
if len(ctx.outputs) == 0:
raise RuntimeError("Trying to perform backward on the InvertibleCheckpoint \
for more than once.")
raise RuntimeError(
"Trying to perform backward on the InvertibleCheckpoint \
for more than once."
)
inputs = ctx.inputs.pop()
outputs = ctx.outputs.pop()
# reconstruct input node features
with torch.no_grad():
# inputs[0] is DGLGraph and inputs[1] is input node features
inputs_inverted = ctx.fn_inverse(*((inputs[0], outputs)+inputs[2:]))
inputs_inverted = ctx.fn_inverse(
*((inputs[0], outputs) + inputs[2:])
)
# clear memory of outputs
outputs.storage().resize_(0)
@@ -72,11 +81,16 @@ class InvertibleCheckpoint(torch.autograd.Function):
detached_inputs = tuple(detached_inputs)
temp_output = ctx.fn(*detached_inputs)
filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False),
detached_inputs))
gradients = torch.autograd.grad(outputs=(temp_output,),
filtered_detached_inputs = tuple(
filter(
lambda x: getattr(x, "requires_grad", False), detached_inputs
)
)
gradients = torch.autograd.grad(
outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs)
grad_outputs=grad_outputs,
)
input_gradients = []
i = 0
@@ -87,7 +101,7 @@ class InvertibleCheckpoint(torch.autograd.Function):
else:
input_gradients.append(None)
gradients = tuple(input_gradients) + gradients[-len(ctx.weights):]
gradients = tuple(input_gradients) + gradients[-len(ctx.weights) :]
return (None, None, None) + gradients
@@ -157,6 +171,7 @@ class GroupRevRes(nn.Module):
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
"""
def __init__(self, gnn_module, groups=2):
super(GroupRevRes, self).__init__()
self.gnn_modules = nn.ModuleList()
@@ -173,7 +188,9 @@ class GroupRevRes(nn.Module):
if len(args) == 0:
args_chunks = [()] * self.groups
else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
chunked_args = list(
map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)
)
args_chunks = list(zip(*chunked_args))
y_in = sum(xs[1:])
@@ -192,13 +209,15 @@ class GroupRevRes(nn.Module):
if len(args) == 0:
args_chunks = [()] * self.groups
else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args))
chunked_args = list(
map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)
)
args_chunks = list(zip(*chunked_args))
xs = []
for i in range(self.groups-1, -1, -1):
for i in range(self.groups - 1, -1, -1):
if i != 0:
y_in = ys[i-1]
y_in = ys[i - 1]
else:
y_in = sum(xs)
@@ -232,6 +251,7 @@ class GroupRevRes(nn.Module):
self._forward,
self._inverse,
len(args),
*(args + tuple([p for p in self.parameters() if p.requires_grad])))
*(args + tuple([p for p in self.parameters() if p.requires_grad]))
)
return y
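GroupRevRes chunks the input features into groups along the last dimension and keeps one copy of the wrapped module per group, so the wrapped GNN must map feats // groups channels to feats // groups channels. A sketch in the spirit of the class docstring; the wrapper layer below is an illustrative stand-in, not code from this commit:

import dgl
import torch
import torch.nn as nn
from dgl.nn import GraphConv, GroupRevRes  # assumed re-exports

class RevGNNLayer(nn.Module):
    """Feature-size-preserving block to be made reversible."""
    def __init__(self, feats):
        super().__init__()
        self.conv = GraphConv(feats, feats, allow_zero_in_degree=True)

    def forward(self, g, x):
        return self.conv(g, x)

g = dgl.rand_graph(5, 20)
feats, groups = 32, 2
x = torch.randn(5, feats)

model = GroupRevRes(RevGNNLayer(feats // groups), groups)
out = model(g, x)  # shape (5, 32)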
"""Heterogeneous Graph Transformer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import torch
import torch.nn as nn
@@ -8,6 +9,7 @@ from .... import function as fn
from ..linear import TypedLinear
from ..softmax import edge_softmax
class HGTConv(nn.Module):
r"""Heterogeneous graph transformer convolution from `Heterogeneous Graph Transformer
<https://arxiv.org/abs/2003.01332>`__
@@ -65,14 +67,17 @@ class HGTConv(nn.Module):
Examples
--------
"""
def __init__(self,
def __init__(
self,
in_size,
head_size,
num_heads,
num_ntypes,
num_etypes,
dropout=0.2,
use_norm=False):
use_norm=False,
):
super().__init__()
self.in_size = in_size
self.head_size = head_size
@@ -83,20 +88,33 @@ class HGTConv(nn.Module):
self.linear_k = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_q = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_v = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_a = TypedLinear(head_size * num_heads, head_size * num_heads, num_ntypes)
self.relation_pri = nn.ParameterList([nn.Parameter(torch.ones(num_etypes))
for i in range(num_heads)])
self.relation_att = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)])
self.relation_msg = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)])
self.linear_a = TypedLinear(
head_size * num_heads, head_size * num_heads, num_ntypes
)
self.relation_pri = nn.ParameterList(
[nn.Parameter(torch.ones(num_etypes)) for i in range(num_heads)]
)
self.relation_att = nn.ModuleList(
[
TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)
]
)
self.relation_msg = nn.ModuleList(
[
TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)
]
)
self.skip = nn.Parameter(torch.ones(num_ntypes))
self.drop = nn.Dropout(dropout)
if use_norm:
self.norm = nn.LayerNorm(head_size * num_heads)
if in_size != head_size * num_heads:
self.residual_w = nn.Parameter(torch.Tensor(in_size, head_size * num_heads))
self.residual_w = nn.Parameter(
torch.Tensor(in_size, head_size * num_heads)
)
nn.init.xavier_uniform_(self.residual_w)
def forward(self, g, x, ntype, etype, *, presorted=False):
@@ -125,17 +143,25 @@ class HGTConv(nn.Module):
"""
self.presorted = presorted
with g.local_scope():
k = self.linear_k(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
q = self.linear_q(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
v = self.linear_v(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
g.srcdata['k'] = k
g.dstdata['q'] = q
g.srcdata['v'] = v
g.edata['etype'] = etype
k = self.linear_k(x, ntype, presorted).view(
-1, self.num_heads, self.head_size
)
q = self.linear_q(x, ntype, presorted).view(
-1, self.num_heads, self.head_size
)
v = self.linear_v(x, ntype, presorted).view(
-1, self.num_heads, self.head_size
)
g.srcdata["k"] = k
g.dstdata["q"] = q
g.srcdata["v"] = v
g.edata["etype"] = etype
g.apply_edges(self.message)
g.edata['m'] = g.edata['m'] * edge_softmax(g, g.edata['a']).unsqueeze(-1)
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'h'))
h = g.dstdata['h'].view(-1, self.num_heads * self.head_size)
g.edata["m"] = g.edata["m"] * edge_softmax(
g, g.edata["a"]
).unsqueeze(-1)
g.update_all(fn.copy_e("m", "m"), fn.sum("m", "h"))
h = g.dstdata["h"].view(-1, self.num_heads * self.head_size)
# target-specific aggregation
h = self.drop(self.linear_a(h, ntype, presorted))
alpha = torch.sigmoid(self.skip[ntype]).unsqueeze(-1)
@@ -150,12 +176,16 @@ class HGTConv(nn.Module):
def message(self, edges):
"""Message function."""
a, m = [], []
etype = edges.data['etype']
k = torch.unbind(edges.src['k'], dim=1)
q = torch.unbind(edges.dst['q'], dim=1)
v = torch.unbind(edges.src['v'], dim=1)
etype = edges.data["etype"]
k = torch.unbind(edges.src["k"], dim=1)
q = torch.unbind(edges.dst["q"], dim=1)
v = torch.unbind(edges.src["v"], dim=1)
for i in range(self.num_heads):
kw = self.relation_att[i](k[i], etype, self.presorted) # (E, O)
a.append((kw * q[i]).sum(-1) * self.relation_pri[i][etype] / self.sqrt_d) # (E,)
m.append(self.relation_msg[i](v[i], etype, self.presorted)) # (E, O)
return {'a' : torch.stack(a, dim=1), 'm' : torch.stack(m, dim=1)}
a.append(
(kw * q[i]).sum(-1) * self.relation_pri[i][etype] / self.sqrt_d
) # (E,)
m.append(
self.relation_msg[i](v[i], etype, self.presorted)
) # (E, O)
return {"a": torch.stack(a, dim=1), "m": torch.stack(m, dim=1)}
@@ -6,6 +6,7 @@ from torch import nn
from .... import function as fn
from ..linear import TypedLinear
class RelGraphConv(nn.Module):
r"""Relational graph convolution layer from `Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__
@@ -94,7 +95,9 @@ class RelGraphConv(nn.Module):
[-0.4323, -0.1440],
[-0.1309, -1.0000]], grad_fn=<AddBackward0>)
"""
def __init__(self,
def __init__(
self,
in_feat,
out_feat,
num_rels,
@@ -104,11 +107,14 @@ class RelGraphConv(nn.Module):
activation=None,
self_loop=True,
dropout=0.0,
layer_norm=False):
layer_norm=False,
):
super().__init__()
if regularizer is not None and num_bases is None:
num_bases = num_rels
self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases)
self.linear_r = TypedLinear(
in_feat, out_feat, num_rels, regularizer, num_bases
)
self.bias = bias
self.activation = activation
self.self_loop = self_loop
@@ -123,21 +129,25 @@ class RelGraphConv(nn.Module):
# the module only about graph convolution.
# layer norm
if self.layer_norm:
self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)
self.layer_norm_weight = nn.LayerNorm(
out_feat, elementwise_affine=True
)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.loop_weight, gain=nn.init.calculate_gain("relu")
)
self.dropout = nn.Dropout(dropout)
def message(self, edges):
"""Message function."""
m = self.linear_r(edges.src['h'], edges.data['etype'], self.presorted)
if 'norm' in edges.data:
m = m * edges.data['norm']
return {'m' : m}
m = self.linear_r(edges.src["h"], edges.data["etype"], self.presorted)
if "norm" in edges.data:
m = m * edges.data["norm"]
return {"m": m}
def forward(self, g, feat, etypes, norm=None, *, presorted=False):
"""Forward computation.
@@ -165,20 +175,20 @@ class RelGraphConv(nn.Module):
"""
self.presorted = presorted
with g.local_scope():
g.srcdata['h'] = feat
g.srcdata["h"] = feat
if norm is not None:
g.edata['norm'] = norm
g.edata['etype'] = etypes
g.edata["norm"] = norm
g.edata["etype"] = etypes
# message passing
g.update_all(self.message, fn.sum('m', 'h'))
g.update_all(self.message, fn.sum("m", "h"))
# apply bias and activation
h = g.dstdata['h']
h = g.dstdata["h"]
if self.layer_norm:
h = self.layer_norm_weight(h)
if self.bias:
h = h + self.h_bias
if self.self_loop:
h = h + feat[:g.num_dst_nodes()] @ self.loop_weight
h = h + feat[: g.num_dst_nodes()] @ self.loop_weight
if self.activation:
h = self.activation(h)
h = self.dropout(h)
......
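RelGraphConv applies a per-relation linear transform through TypedLinear inside the message function above, with relation ids passed per edge. A minimal call using basis regularization (import path assumed):

import dgl
import torch as th
from dgl.nn import RelGraphConv  # assumed re-export

g = dgl.rand_graph(6, 15)
feat = th.randn(6, 10)
etypes = th.randint(0, 3, (15,))  # relation id per edge

conv = RelGraphConv(10, 2, num_rels=3, regularizer="basis", num_bases=2)
out = conv(g, feat, etypes)  # shape (6, 2)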