Unverified Commit 76bb5404 authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
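The change is mechanical: Black normalizes string quotes to double quotes, re-wraps long signatures and call sites with one argument per line plus a trailing comma, and reflows nested expressions. As a rough sketch (assuming the `black` package is installed; the snippet and line length below are illustrative, not taken from the DGL source tree), the same kind of rewrite can be reproduced programmatically:

# Minimal sketch: feed Black a snippet written in the pre-commit style
# and print the reformatted result.
import black

SRC = (
    "def make_layer(in_feats,\n"
    "               out_feats,\n"
    "               norm='both',\n"
    "               bias=True):\n"
    "    return {'in': in_feats, 'out': out_feats}\n"
)
print(black.format_str(SRC, mode=black.Mode(line_length=79)))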
@@ -63,12 +63,10 @@ class DenseGraphConv(nn.Module):
    --------
    `GraphConv <https://docs.dgl.ai/api/python/nn.pytorch.html#graphconv>`__
    """

    def __init__(
        self, in_feats, out_feats, norm="both", bias=True, activation=None
    ):
        super(DenseGraphConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
@@ -77,7 +75,7 @@ class DenseGraphConv(nn.Module):
        if bias:
            self.bias = nn.Parameter(th.Tensor(out_feats))
        else:
            self.register_buffer("bias", None)
        self.reset_parameters()
        self._activation = activation
@@ -114,7 +112,7 @@ class DenseGraphConv(nn.Module):
        dst_degrees = adj.sum(dim=1).clamp(min=1)
        feat_src = feat
        if self._norm == "both":
            norm_src = th.pow(src_degrees, -0.5)
            shp = norm_src.shape + (1,) * (feat.dim() - 1)
            norm_src = th.reshape(norm_src, shp).to(feat.device)
@@ -129,10 +127,10 @@ class DenseGraphConv(nn.Module):
        rst = adj @ feat_src
        rst = th.matmul(rst, self.weight)
        if self._norm != "none":
            if self._norm == "both":
                norm_dst = th.pow(dst_degrees, -0.5)
            else:  # right
                norm_dst = 1.0 / dst_degrees
            shp = norm_dst.shape + (1,) * (feat.dim() - 1)
            norm_dst = th.reshape(norm_dst, shp).to(feat.device)
"""Torch Module for DenseSAGEConv""" """Torch Module for DenseSAGEConv"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
from torch import nn from torch import nn
from ....utils import check_eq_shape from ....utils import check_eq_shape
...@@ -56,13 +57,16 @@ class DenseSAGEConv(nn.Module): ...@@ -56,13 +57,16 @@ class DenseSAGEConv(nn.Module):
-------- --------
`SAGEConv <https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv>`__ `SAGEConv <https://docs.dgl.ai/api/python/nn.pytorch.html#sageconv>`__
""" """
def __init__(self,
in_feats, def __init__(
out_feats, self,
feat_drop=0., in_feats,
bias=True, out_feats,
norm=None, feat_drop=0.0,
activation=None): bias=True,
norm=None,
activation=None,
):
super(DenseSAGEConv, self).__init__() super(DenseSAGEConv, self).__init__()
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
...@@ -83,7 +87,7 @@ class DenseSAGEConv(nn.Module): ...@@ -83,7 +87,7 @@ class DenseSAGEConv(nn.Module):
----- -----
The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization. The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
""" """
gain = nn.init.calculate_gain('relu') gain = nn.init.calculate_gain("relu")
nn.init.xavier_uniform_(self.fc.weight, gain=gain) nn.init.xavier_uniform_(self.fc.weight, gain=gain)
def forward(self, adj, feat): def forward(self, adj, feat):
......
"""Torch Module for Directional Graph Networks Convolution Layer""" """Torch Module for Directional Graph Networks Convolution Layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
from functools import partial from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from .pnaconv import AGGREGATORS, SCALERS, PNAConv, PNAConvTower from .pnaconv import AGGREGATORS, SCALERS, PNAConv, PNAConvTower
def aggregate_dir_av(h, eig_s, eig_d, eig_idx): def aggregate_dir_av(h, eig_s, eig_d, eig_idx):
"""directional average aggregation""" """directional average aggregation"""
h_mod = torch.mul(h, ( h_mod = torch.mul(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) / h,
(torch.sum(torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]), (
keepdim=True, dim=1) + 1e-30)).unsqueeze(-1)) torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])
/ (
torch.sum(
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True,
dim=1,
)
+ 1e-30
)
).unsqueeze(-1),
)
return torch.sum(h_mod, dim=1) return torch.sum(h_mod, dim=1)
def aggregate_dir_dx(h, eig_s, eig_d, h_in, eig_idx): def aggregate_dir_dx(h, eig_s, eig_d, h_in, eig_idx):
"""directional derivative aggregation""" """directional derivative aggregation"""
eig_w = (( eig_w = (
eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]) / (eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx])
(torch.sum( / (
torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]), torch.sum(
keepdim=True, dim=1) + 1e-30 torch.abs(eig_s[:, :, eig_idx] - eig_d[:, :, eig_idx]),
keepdim=True,
dim=1,
)
+ 1e-30
) )
).unsqueeze(-1) ).unsqueeze(-1)
h_mod = torch.mul(h, eig_w) h_mod = torch.mul(h, eig_w)
return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in) return torch.abs(torch.sum(h_mod, dim=1) - torch.sum(eig_w, dim=1) * h_in)
for k in range(1, 4): for k in range(1, 4):
AGGREGATORS[f'dir{k}-av'] = partial(aggregate_dir_av, eig_idx=k-1) AGGREGATORS[f"dir{k}-av"] = partial(aggregate_dir_av, eig_idx=k - 1)
AGGREGATORS[f'dir{k}-dx'] = partial(aggregate_dir_dx, eig_idx=k-1) AGGREGATORS[f"dir{k}-dx"] = partial(aggregate_dir_dx, eig_idx=k - 1)
class DGNConvTower(PNAConvTower): class DGNConvTower(PNAConvTower):
"""A single DGN tower with modified reduce function""" """A single DGN tower with modified reduce function"""
def message(self, edges): def message(self, edges):
"""message function for DGN layer""" """message function for DGN layer"""
if self.edge_feat_size > 0: if self.edge_feat_size > 0:
f = torch.cat([edges.src['h'], edges.dst['h'], edges.data['a']], dim=-1) f = torch.cat(
[edges.src["h"], edges.dst["h"], edges.data["a"]], dim=-1
)
else: else:
f = torch.cat([edges.src['h'], edges.dst['h']], dim=-1) f = torch.cat([edges.src["h"], edges.dst["h"]], dim=-1)
return {'msg': self.M(f), 'eig_s': edges.src['eig'], 'eig_d': edges.dst['eig']} return {
"msg": self.M(f),
"eig_s": edges.src["eig"],
"eig_d": edges.dst["eig"],
}
def reduce_func(self, nodes): def reduce_func(self, nodes):
"""reduce function for DGN layer""" """reduce function for DGN layer"""
h_in = nodes.data['h'] h_in = nodes.data["h"]
eig_s = nodes.mailbox['eig_s'] eig_s = nodes.mailbox["eig_s"]
eig_d = nodes.mailbox['eig_d'] eig_d = nodes.mailbox["eig_d"]
msg = nodes.mailbox['msg'] msg = nodes.mailbox["msg"]
degree = msg.size(1) degree = msg.size(1)
h = [] h = []
for agg in self.aggregators: for agg in self.aggregators:
if agg.startswith('dir'): if agg.startswith("dir"):
if agg.endswith('av'): if agg.endswith("av"):
h.append(AGGREGATORS[agg](msg, eig_s, eig_d)) h.append(AGGREGATORS[agg](msg, eig_s, eig_d))
else: else:
h.append(AGGREGATORS[agg](msg, eig_s, eig_d, h_in)) h.append(AGGREGATORS[agg](msg, eig_s, eig_d, h_in))
else: else:
h.append(AGGREGATORS[agg](msg)) h.append(AGGREGATORS[agg](msg))
h = torch.cat(h, dim=1) h = torch.cat(h, dim=1)
h = torch.cat([ h = torch.cat(
SCALERS[scaler](h, D=degree, delta=self.delta) if scaler != 'identity' else h [
for scaler in self.scalers SCALERS[scaler](h, D=degree, delta=self.delta)
], dim=1) if scaler != "identity"
return {'h_neigh': h} else h
for scaler in self.scalers
],
dim=1,
)
return {"h_neigh": h}
class DGNConv(PNAConv): class DGNConv(PNAConv):
r"""Directional Graph Network Layer from `Directional Graph Networks r"""Directional Graph Network Layer from `Directional Graph Networks
...@@ -154,24 +187,49 @@ class DGNConv(PNAConv): ...@@ -154,24 +187,49 @@ class DGNConv(PNAConv):
>>> conv = DGNConv(10, 10, ['dir1-av', 'dir1-dx', 'sum'], ['identity', 'amplification'], 2.5) >>> conv = DGNConv(10, 10, ['dir1-av', 'dir1-dx', 'sum'], ['identity', 'amplification'], 2.5)
>>> ret = conv(g, feat, eig_vec=eig) >>> ret = conv(g, feat, eig_vec=eig)
""" """
def __init__(self, in_size, out_size, aggregators, scalers, delta,
dropout=0., num_towers=1, edge_feat_size=0, residual=True): def __init__(
self,
in_size,
out_size,
aggregators,
scalers,
delta,
dropout=0.0,
num_towers=1,
edge_feat_size=0,
residual=True,
):
super(DGNConv, self).__init__( super(DGNConv, self).__init__(
in_size, out_size, aggregators, scalers, delta, dropout, in_size,
num_towers, edge_feat_size, residual out_size,
aggregators,
scalers,
delta,
dropout,
num_towers,
edge_feat_size,
residual,
) )
self.towers = nn.ModuleList([ self.towers = nn.ModuleList(
DGNConvTower( [
self.tower_in_size, self.tower_out_size, DGNConvTower(
aggregators, scalers, delta, self.tower_in_size,
dropout=dropout, edge_feat_size=edge_feat_size self.tower_out_size,
) for _ in range(num_towers) aggregators,
]) scalers,
delta,
dropout=dropout,
edge_feat_size=edge_feat_size,
)
for _ in range(num_towers)
]
)
self.use_eig_vec = False self.use_eig_vec = False
for aggr in aggregators: for aggr in aggregators:
if aggr.startswith('dir'): if aggr.startswith("dir"):
self.use_eig_vec = True self.use_eig_vec = True
break break
...@@ -203,5 +261,5 @@ class DGNConv(PNAConv): ...@@ -203,5 +261,5 @@ class DGNConv(PNAConv):
""" """
with graph.local_scope(): with graph.local_scope():
if self.use_eig_vec: if self.use_eig_vec:
graph.ndata['eig'] = eig_vec graph.ndata["eig"] = eig_vec
return super().forward(graph, node_feat, edge_feat) return super().forward(graph, node_feat, edge_feat)
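A small usage sketch for the layer above, expanding the docstring example; the random `eig` tensor merely stands in for precomputed Laplacian eigenvector features, and the graph and feature sizes are arbitrary:

import dgl
import torch
from dgl.nn import DGNConv

# Toy graph and features; 3 eigenvector columns cover the dir1/dir2/dir3
# aggregators registered above.
g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
g = dgl.add_self_loop(g)
feat = torch.ones(6, 10)
eig = torch.rand(6, 3)  # stand-in for Laplacian eigenvectors

conv = DGNConv(
    10, 10, ["dir1-av", "dir1-dx", "sum"], ["identity", "amplification"], 2.5
)
ret = conv(g, feat, eig_vec=eig)  # shape: (6, 10)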
@@ -2,6 +2,7 @@
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
import torch.nn as nn

from .... import function as fn
@@ -47,6 +48,7 @@ class EGNNConv(nn.Module):
    >>> conv = EGNNConv(10, 10, 10, 2)
    >>> h, x = conv(g, node_feat, coord_feat, edge_feat)
    """

    def __init__(self, in_size, hidden_size, out_size, edge_feat_size=0):
        super(EGNNConv, self).__init__()
@@ -62,21 +64,21 @@ class EGNNConv(nn.Module):
            nn.Linear(in_size * 2 + edge_feat_size + 1, hidden_size),
            act_fn,
            nn.Linear(hidden_size, hidden_size),
            act_fn,
        )

        # \phi_h
        self.node_mlp = nn.Sequential(
            nn.Linear(in_size + hidden_size, hidden_size),
            act_fn,
            nn.Linear(hidden_size, out_size),
        )

        # \phi_x
        self.coord_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            act_fn,
            nn.Linear(hidden_size, 1, bias=False),
        )

    def message(self, edges):
@@ -84,16 +86,23 @@ class EGNNConv(nn.Module):
        # concat features for edge mlp
        if self.edge_feat_size > 0:
            f = torch.cat(
                [
                    edges.src["h"],
                    edges.dst["h"],
                    edges.data["radial"],
                    edges.data["a"],
                ],
                dim=-1,
            )
        else:
            f = torch.cat(
                [edges.src["h"], edges.dst["h"], edges.data["radial"]], dim=-1
            )

        msg_h = self.edge_mlp(f)
        msg_x = self.coord_mlp(msg_h) * edges.data["x_diff"]

        return {"msg_x": msg_x, "msg_h": msg_h}

    def forward(self, graph, node_feat, coord_feat, edge_feat=None):
        r"""
@@ -126,27 +135,29 @@ class EGNNConv(nn.Module):
        """
        with graph.local_scope():
            # node feature
            graph.ndata["h"] = node_feat
            # coordinate feature
            graph.ndata["x"] = coord_feat
            # edge feature
            if self.edge_feat_size > 0:
                assert edge_feat is not None, "Edge features must be provided."
                graph.edata["a"] = edge_feat
            # get coordinate diff & radial features
            graph.apply_edges(fn.u_sub_v("x", "x", "x_diff"))
            graph.edata["radial"] = (
                graph.edata["x_diff"].square().sum(dim=1).unsqueeze(-1)
            )
            # normalize coordinate difference
            graph.edata["x_diff"] = graph.edata["x_diff"] / (
                graph.edata["radial"].sqrt() + 1e-30
            )
            graph.apply_edges(self.message)
            graph.update_all(fn.copy_e("msg_x", "m"), fn.mean("m", "x_neigh"))
            graph.update_all(fn.copy_e("msg_h", "m"), fn.sum("m", "h_neigh"))

            h_neigh, x_neigh = graph.ndata["h_neigh"], graph.ndata["x_neigh"]

            h = self.node_mlp(torch.cat([node_feat, h_neigh], dim=-1))
            x = coord_feat + x_neigh

            return h, x
@@ -58,12 +58,7 @@ class GatedGraphConv(nn.Module):
             0.1342,  0.0425]], grad_fn=<AddBackward0>)
    """

    def __init__(self, in_feats, out_feats, n_steps, n_etypes, bias=True):
        super(GatedGraphConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
@@ -87,7 +82,7 @@ class GatedGraphConv(nn.Module):
        The model parameters are initialized using Glorot uniform initialization
        and the bias is initialized to be zero.
        """
        gain = init.calculate_gain("relu")
        self.gru.reset_parameters()
        for linear in self.linears:
            init.xavier_normal_(linear.weight, gain=gain)
@@ -134,36 +129,44 @@ class GatedGraphConv(nn.Module):
            is the output feature size.
        """
        with graph.local_scope():
            assert graph.is_homogeneous, (
                "not a homogeneous graph; convert it with to_homogeneous "
                "and pass in the edge type as argument"
            )
            if self._n_etypes != 1:
                assert (
                    etypes.min() >= 0 and etypes.max() < self._n_etypes
                ), "edge type indices out of range [0, {})".format(
                    self._n_etypes
                )
            zero_pad = feat.new_zeros(
                (feat.shape[0], self._out_feats - feat.shape[1])
            )
            feat = th.cat([feat, zero_pad], -1)
            for _ in range(self._n_steps):
                if self._n_etypes == 1 and etypes is None:
                    # Fast path when graph has only one edge type
                    graph.ndata["h"] = self.linears[0](feat)
                    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "a"))
                    a = graph.ndata.pop("a")  # (N, D)
                else:
                    graph.ndata["h"] = feat
                    for i in range(self._n_etypes):
                        eids = (
                            th.nonzero(etypes == i, as_tuple=False)
                            .view(-1)
                            .type(graph.idtype)
                        )
                        if len(eids) > 0:
                            graph.apply_edges(
                                lambda edges: {
                                    "W_e*h": self.linears[i](edges.src["h"])
                                },
                                eids,
                            )
                    graph.update_all(fn.copy_e("W_e*h", "m"), fn.sum("m", "a"))
                    a = graph.ndata.pop("a")  # (N, D)
                feat = self.gru(a, feat)
            return feat
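A minimal usage sketch for the fast path above (a single edge type, so `etypes` can be left out; sizes are illustrative):

import dgl
import torch as th
from dgl.nn import GatedGraphConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feat = th.ones(4, 5)
conv = GatedGraphConv(5, 10, n_steps=2, n_etypes=1)
res = conv(g, feat)  # etypes omitted, so the single-edge-type fast path runs
print(res.shape)     # torch.Size([4, 10])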
@@ -4,10 +4,11 @@ import torch as th
from torch import nn

from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
from ...functional import edge_softmax
from ..utils import Identity

# pylint: enable=W0235
class GATv2Conv(nn.Module):
@@ -134,18 +135,21 @@ class GATv2Conv(nn.Module):
             [-1.1850,  0.1123],
             [-0.2002,  0.1155]]], grad_fn=<GSpMMBackward>)
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads,
        feat_drop=0.0,
        attn_drop=0.0,
        negative_slope=0.2,
        residual=False,
        activation=None,
        allow_zero_in_degree=False,
        bias=True,
        share_weights=False,
    ):
        super(GATv2Conv, self).__init__()
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
@@ -153,17 +157,21 @@ class GATv2Conv(nn.Module):
        self._allow_zero_in_degree = allow_zero_in_degree
        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=bias
            )
            self.fc_dst = nn.Linear(
                self._in_dst_feats, out_feats * num_heads, bias=bias
            )
        else:
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=bias
            )
            if share_weights:
                self.fc_dst = self.fc_src
            else:
                self.fc_dst = nn.Linear(
                    self._in_src_feats, out_feats * num_heads, bias=bias
                )
        self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
@@ -171,11 +179,12 @@ class GATv2Conv(nn.Module):
        if residual:
            if self._in_dst_feats != out_feats * num_heads:
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=bias
                )
            else:
                self.res_fc = Identity()
        else:
            self.register_buffer("res_fc", None)
        self.activation = activation
        self.share_weights = share_weights
        self.bias = bias
@@ -192,7 +201,7 @@ class GATv2Conv(nn.Module):
        The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
        The attention weights are using xavier initialization method.
        """
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
        if self.bias:
            nn.init.constant_(self.fc_src.bias, 0)
@@ -256,53 +265,70 @@ class GATv2Conv(nn.Module):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )
            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                feat_src = self.fc_src(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                feat_dst = self.fc_dst(h_dst).view(
                    -1, self._num_heads, self._out_feats
                )
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = self.fc_src(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                if self.share_weights:
                    feat_dst = feat_src
                else:
                    feat_dst = self.fc_dst(h_src).view(
                        -1, self._num_heads, self._out_feats
                    )
                if graph.is_block:
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]
                    h_dst = h_dst[: graph.number_of_dst_nodes()]
            graph.srcdata.update(
                {"el": feat_src}
            )  # (num_src_edge, num_heads, out_dim)
            graph.dstdata.update({"er": feat_dst})
            graph.apply_edges(fn.u_add_v("el", "er", "e"))
            e = self.leaky_relu(
                graph.edata.pop("e")
            )  # (num_src_edge, num_heads, out_dim)
            e = (
                (e * self.attn).sum(dim=-1).unsqueeze(dim=2)
            )  # (num_edge, num_heads, 1)
            # compute softmax
            graph.edata["a"] = self.attn_drop(
                edge_softmax(graph, e)
            )  # (num_edge, num_heads)
            # message passing
            graph.update_all(fn.u_mul_e("el", "a", "m"), fn.sum("m", "ft"))
            rst = graph.dstdata["ft"]
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(
                    h_dst.shape[0], -1, self._out_feats
                )
                rst = rst + resval
            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata["a"]
            else:
                return rst
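The zero-in-degree check above is why the usual examples call dgl.add_self_loop first; a minimal usage sketch with illustrative shapes:

import dgl
import torch as th
from dgl.nn import GATv2Conv

g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
g = dgl.add_self_loop(g)  # avoids the DGLError raised for 0-in-degree nodes
feat = th.ones(6, 10)
conv = GATv2Conv(10, 2, num_heads=3)
rst = conv(g, feat)                            # shape: (6, 3, 2)
rst, attn = conv(g, feat, get_attention=True)  # attn: one weight per edge and head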
@@ -104,15 +104,17 @@ class GCN2Conv(nn.Module):
    """

    def __init__(
        self,
        in_feats,
        layer,
        alpha=0.1,
        lambda_=1,
        project_initial_features=True,
        allow_zero_in_degree=False,
        bias=True,
        activation=None,
    ):
        super().__init__()
        self._in_feats = in_feats
@@ -131,7 +133,8 @@ class GCN2Conv(nn.Module):
            self.register_parameter("weight2", None)
        else:
            self.weight2 = nn.Parameter(
                th.Tensor(self._in_feats, self._in_feats)
            )
        if self._bias:
            self.bias = nn.Parameter(th.Tensor(self._in_feats))
@@ -233,7 +236,7 @@ class GCN2Conv(nn.Module):
                norm = th.pow(degs, -0.5)
                norm = norm.to(feat.device).unsqueeze(1)
            else:
                edge_weight = EdgeWeightNorm("both")(graph, edge_weight)
            if edge_weight is None:
                feat = feat * norm
@@ -255,14 +258,26 @@ class GCN2Conv(nn.Module):
            if self._project_initial_features:
                rst = feat.add_(feat_0)
                rst = th.addmm(
                    feat,
                    feat,
                    self.weight1,
                    beta=(1 - self.beta),
                    alpha=self.beta,
                )
            else:
                rst = th.addmm(
                    feat,
                    feat,
                    self.weight1,
                    beta=(1 - self.beta),
                    alpha=self.beta,
                )
                rst += th.addmm(
                    feat_0,
                    feat_0,
                    self.weight2,
                    beta=(1 - self.beta),
                    alpha=self.beta,
                )

            if self._bias:
@@ -7,6 +7,7 @@ from torch import nn
from .... import function as fn
from ....utils import expand_as_pair


class GINEConv(nn.Module):
    r"""Graph Isomorphism Network with Edge Features, introduced by
    `Strategies for Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
@@ -45,21 +46,19 @@ class GINEConv(nn.Module):
    >>> print(res.shape)
    torch.Size([4, 20])
    """

    def __init__(self, apply_func=None, init_eps=0, learn_eps=False):
        super(GINEConv, self).__init__()
        self.apply_func = apply_func
        # to specify whether eps is trainable or not.
        if learn_eps:
            self.eps = nn.Parameter(th.FloatTensor([init_eps]))
        else:
            self.register_buffer("eps", th.FloatTensor([init_eps]))

    def message(self, edges):
        r"""User-defined Message Function"""
        return {"m": F.relu(edges.src["hn"] + edges.data["he"])}

    def forward(self, graph, node_feat, edge_feat):
        r"""Forward computation.
@@ -89,10 +88,10 @@ class GINEConv(nn.Module):
        """
        with graph.local_scope():
            feat_src, feat_dst = expand_as_pair(node_feat, graph)
            graph.srcdata["hn"] = feat_src
            graph.edata["he"] = edge_feat
            graph.update_all(self.message, fn.sum("m", "neigh"))
            rst = (1 + self.eps) * feat_dst + graph.dstdata["neigh"]
            if self.apply_func is not None:
                rst = self.apply_func(rst)
            return rst
"""Torch module for grouped reversible residual connections for GNNs""" """Torch module for grouped reversible residual connections for GNNs"""
# pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728 # pylint: disable= no-member, arguments-differ, invalid-name, C0116, R1728
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
class InvertibleCheckpoint(torch.autograd.Function): class InvertibleCheckpoint(torch.autograd.Function):
r"""Extension of torch.autograd""" r"""Extension of torch.autograd"""
@staticmethod @staticmethod
def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights): def forward(ctx, fn, fn_inverse, num_inputs, *inputs_and_weights):
ctx.fn = fn ctx.fn = fn
...@@ -40,19 +43,25 @@ class InvertibleCheckpoint(torch.autograd.Function): ...@@ -40,19 +43,25 @@ class InvertibleCheckpoint(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *grad_outputs): def backward(ctx, *grad_outputs):
if not torch.autograd._is_checkpoint_valid(): if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("InvertibleCheckpoint is not compatible with .grad(), \ raise RuntimeError(
please use .backward() if possible") "InvertibleCheckpoint is not compatible with .grad(), \
please use .backward() if possible"
)
# retrieve input and output tensor nodes # retrieve input and output tensor nodes
if len(ctx.outputs) == 0: if len(ctx.outputs) == 0:
raise RuntimeError("Trying to perform backward on the InvertibleCheckpoint \ raise RuntimeError(
for more than once.") "Trying to perform backward on the InvertibleCheckpoint \
for more than once."
)
inputs = ctx.inputs.pop() inputs = ctx.inputs.pop()
outputs = ctx.outputs.pop() outputs = ctx.outputs.pop()
# reconstruct input node features # reconstruct input node features
with torch.no_grad(): with torch.no_grad():
# inputs[0] is DGLGraph and inputs[1] is input node features # inputs[0] is DGLGraph and inputs[1] is input node features
inputs_inverted = ctx.fn_inverse(*((inputs[0], outputs)+inputs[2:])) inputs_inverted = ctx.fn_inverse(
*((inputs[0], outputs) + inputs[2:])
)
# clear memory of outputs # clear memory of outputs
outputs.storage().resize_(0) outputs.storage().resize_(0)
...@@ -72,11 +81,16 @@ class InvertibleCheckpoint(torch.autograd.Function): ...@@ -72,11 +81,16 @@ class InvertibleCheckpoint(torch.autograd.Function):
detached_inputs = tuple(detached_inputs) detached_inputs = tuple(detached_inputs)
temp_output = ctx.fn(*detached_inputs) temp_output = ctx.fn(*detached_inputs)
filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False), filtered_detached_inputs = tuple(
detached_inputs)) filter(
gradients = torch.autograd.grad(outputs=(temp_output,), lambda x: getattr(x, "requires_grad", False), detached_inputs
inputs=filtered_detached_inputs + ctx.weights, )
grad_outputs=grad_outputs) )
gradients = torch.autograd.grad(
outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs,
)
input_gradients = [] input_gradients = []
i = 0 i = 0
...@@ -87,7 +101,7 @@ class InvertibleCheckpoint(torch.autograd.Function): ...@@ -87,7 +101,7 @@ class InvertibleCheckpoint(torch.autograd.Function):
else: else:
input_gradients.append(None) input_gradients.append(None)
gradients = tuple(input_gradients) + gradients[-len(ctx.weights):] gradients = tuple(input_gradients) + gradients[-len(ctx.weights) :]
return (None, None, None) + gradients return (None, None, None) + gradients
...@@ -157,6 +171,7 @@ class GroupRevRes(nn.Module): ...@@ -157,6 +171,7 @@ class GroupRevRes(nn.Module):
>>> model = GroupRevRes(conv, groups) >>> model = GroupRevRes(conv, groups)
>>> out = model(g, x) >>> out = model(g, x)
""" """
def __init__(self, gnn_module, groups=2): def __init__(self, gnn_module, groups=2):
super(GroupRevRes, self).__init__() super(GroupRevRes, self).__init__()
self.gnn_modules = nn.ModuleList() self.gnn_modules = nn.ModuleList()
...@@ -173,7 +188,9 @@ class GroupRevRes(nn.Module): ...@@ -173,7 +188,9 @@ class GroupRevRes(nn.Module):
if len(args) == 0: if len(args) == 0:
args_chunks = [()] * self.groups args_chunks = [()] * self.groups
else: else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)) chunked_args = list(
map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)
)
args_chunks = list(zip(*chunked_args)) args_chunks = list(zip(*chunked_args))
y_in = sum(xs[1:]) y_in = sum(xs[1:])
...@@ -192,13 +209,15 @@ class GroupRevRes(nn.Module): ...@@ -192,13 +209,15 @@ class GroupRevRes(nn.Module):
if len(args) == 0: if len(args) == 0:
args_chunks = [()] * self.groups args_chunks = [()] * self.groups
else: else:
chunked_args = list(map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)) chunked_args = list(
map(lambda arg: torch.chunk(arg, self.groups, dim=-1), args)
)
args_chunks = list(zip(*chunked_args)) args_chunks = list(zip(*chunked_args))
xs = [] xs = []
for i in range(self.groups-1, -1, -1): for i in range(self.groups - 1, -1, -1):
if i != 0: if i != 0:
y_in = ys[i-1] y_in = ys[i - 1]
else: else:
y_in = sum(xs) y_in = sum(xs)
...@@ -232,6 +251,7 @@ class GroupRevRes(nn.Module): ...@@ -232,6 +251,7 @@ class GroupRevRes(nn.Module):
self._forward, self._forward,
self._inverse, self._inverse,
len(args), len(args),
*(args + tuple([p for p in self.parameters() if p.requires_grad]))) *(args + tuple([p for p in self.parameters() if p.requires_grad]))
)
return y return y
"""Heterogeneous Graph Transformer""" """Heterogeneous Graph Transformer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -8,6 +9,7 @@ from .... import function as fn ...@@ -8,6 +9,7 @@ from .... import function as fn
from ..linear import TypedLinear from ..linear import TypedLinear
from ..softmax import edge_softmax from ..softmax import edge_softmax
class HGTConv(nn.Module): class HGTConv(nn.Module):
r"""Heterogeneous graph transformer convolution from `Heterogeneous Graph Transformer r"""Heterogeneous graph transformer convolution from `Heterogeneous Graph Transformer
<https://arxiv.org/abs/2003.01332>`__ <https://arxiv.org/abs/2003.01332>`__
...@@ -65,14 +67,17 @@ class HGTConv(nn.Module): ...@@ -65,14 +67,17 @@ class HGTConv(nn.Module):
Examples Examples
-------- --------
""" """
def __init__(self,
in_size, def __init__(
head_size, self,
num_heads, in_size,
num_ntypes, head_size,
num_etypes, num_heads,
dropout=0.2, num_ntypes,
use_norm=False): num_etypes,
dropout=0.2,
use_norm=False,
):
super().__init__() super().__init__()
self.in_size = in_size self.in_size = in_size
self.head_size = head_size self.head_size = head_size
...@@ -83,20 +88,33 @@ class HGTConv(nn.Module): ...@@ -83,20 +88,33 @@ class HGTConv(nn.Module):
self.linear_k = TypedLinear(in_size, head_size * num_heads, num_ntypes) self.linear_k = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_q = TypedLinear(in_size, head_size * num_heads, num_ntypes) self.linear_q = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_v = TypedLinear(in_size, head_size * num_heads, num_ntypes) self.linear_v = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_a = TypedLinear(head_size * num_heads, head_size * num_heads, num_ntypes) self.linear_a = TypedLinear(
head_size * num_heads, head_size * num_heads, num_ntypes
self.relation_pri = nn.ParameterList([nn.Parameter(torch.ones(num_etypes)) )
for i in range(num_heads)])
self.relation_att = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes) self.relation_pri = nn.ParameterList(
for i in range(num_heads)]) [nn.Parameter(torch.ones(num_etypes)) for i in range(num_heads)]
self.relation_msg = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes) )
for i in range(num_heads)]) self.relation_att = nn.ModuleList(
[
TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)
]
)
self.relation_msg = nn.ModuleList(
[
TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)
]
)
self.skip = nn.Parameter(torch.ones(num_ntypes)) self.skip = nn.Parameter(torch.ones(num_ntypes))
self.drop = nn.Dropout(dropout) self.drop = nn.Dropout(dropout)
if use_norm: if use_norm:
self.norm = nn.LayerNorm(head_size * num_heads) self.norm = nn.LayerNorm(head_size * num_heads)
if in_size != head_size * num_heads: if in_size != head_size * num_heads:
self.residual_w = nn.Parameter(torch.Tensor(in_size, head_size * num_heads)) self.residual_w = nn.Parameter(
torch.Tensor(in_size, head_size * num_heads)
)
nn.init.xavier_uniform_(self.residual_w) nn.init.xavier_uniform_(self.residual_w)
def forward(self, g, x, ntype, etype, *, presorted=False): def forward(self, g, x, ntype, etype, *, presorted=False):
...@@ -125,17 +143,25 @@ class HGTConv(nn.Module): ...@@ -125,17 +143,25 @@ class HGTConv(nn.Module):
""" """
self.presorted = presorted self.presorted = presorted
with g.local_scope(): with g.local_scope():
k = self.linear_k(x, ntype, presorted).view(-1, self.num_heads, self.head_size) k = self.linear_k(x, ntype, presorted).view(
q = self.linear_q(x, ntype, presorted).view(-1, self.num_heads, self.head_size) -1, self.num_heads, self.head_size
v = self.linear_v(x, ntype, presorted).view(-1, self.num_heads, self.head_size) )
g.srcdata['k'] = k q = self.linear_q(x, ntype, presorted).view(
g.dstdata['q'] = q -1, self.num_heads, self.head_size
g.srcdata['v'] = v )
g.edata['etype'] = etype v = self.linear_v(x, ntype, presorted).view(
-1, self.num_heads, self.head_size
)
g.srcdata["k"] = k
g.dstdata["q"] = q
g.srcdata["v"] = v
g.edata["etype"] = etype
g.apply_edges(self.message) g.apply_edges(self.message)
g.edata['m'] = g.edata['m'] * edge_softmax(g, g.edata['a']).unsqueeze(-1) g.edata["m"] = g.edata["m"] * edge_softmax(
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'h')) g, g.edata["a"]
h = g.dstdata['h'].view(-1, self.num_heads * self.head_size) ).unsqueeze(-1)
g.update_all(fn.copy_e("m", "m"), fn.sum("m", "h"))
h = g.dstdata["h"].view(-1, self.num_heads * self.head_size)
# target-specific aggregation # target-specific aggregation
h = self.drop(self.linear_a(h, ntype, presorted)) h = self.drop(self.linear_a(h, ntype, presorted))
alpha = torch.sigmoid(self.skip[ntype]).unsqueeze(-1) alpha = torch.sigmoid(self.skip[ntype]).unsqueeze(-1)
...@@ -150,12 +176,16 @@ class HGTConv(nn.Module): ...@@ -150,12 +176,16 @@ class HGTConv(nn.Module):
def message(self, edges): def message(self, edges):
"""Message function.""" """Message function."""
a, m = [], [] a, m = [], []
etype = edges.data['etype'] etype = edges.data["etype"]
k = torch.unbind(edges.src['k'], dim=1) k = torch.unbind(edges.src["k"], dim=1)
q = torch.unbind(edges.dst['q'], dim=1) q = torch.unbind(edges.dst["q"], dim=1)
v = torch.unbind(edges.src['v'], dim=1) v = torch.unbind(edges.src["v"], dim=1)
for i in range(self.num_heads): for i in range(self.num_heads):
kw = self.relation_att[i](k[i], etype, self.presorted) # (E, O) kw = self.relation_att[i](k[i], etype, self.presorted) # (E, O)
a.append((kw * q[i]).sum(-1) * self.relation_pri[i][etype] / self.sqrt_d) # (E,) a.append(
m.append(self.relation_msg[i](v[i], etype, self.presorted)) # (E, O) (kw * q[i]).sum(-1) * self.relation_pri[i][etype] / self.sqrt_d
return {'a' : torch.stack(a, dim=1), 'm' : torch.stack(m, dim=1)} ) # (E,)
m.append(
self.relation_msg[i](v[i], etype, self.presorted)
) # (E, O)
return {"a": torch.stack(a, dim=1), "m": torch.stack(m, dim=1)}
@@ -4,26 +4,32 @@ import numpy as np
import torch
import torch.nn as nn


def aggregate_mean(h):
    """mean aggregation"""
    return torch.mean(h, dim=1)


def aggregate_max(h):
    """max aggregation"""
    return torch.max(h, dim=1)[0]


def aggregate_min(h):
    """min aggregation"""
    return torch.min(h, dim=1)[0]


def aggregate_sum(h):
    """sum aggregation"""
    return torch.sum(h, dim=1)


def aggregate_std(h):
    """standard deviation aggregation"""
    return torch.sqrt(aggregate_var(h) + 1e-30)


def aggregate_var(h):
    """variance aggregation"""
    h_mean_squares = torch.mean(h * h, dim=1)
@@ -31,52 +37,76 @@ def aggregate_var(h):
    var = torch.relu(h_mean_squares - h_mean * h_mean)
    return var


def _aggregate_moment(h, n):
    """moment aggregation: for each node (E[(X-E[X])^n])^{1/n}"""
    h_mean = torch.mean(h, dim=1, keepdim=True)
    h_n = torch.mean(torch.pow(h - h_mean, n), dim=1)
    rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + 1e-30, 1.0 / n)
    return rooted_h_n


def aggregate_moment_3(h):
    """moment aggregation with n=3"""
    return _aggregate_moment(h, n=3)


def aggregate_moment_4(h):
    """moment aggregation with n=4"""
    return _aggregate_moment(h, n=4)


def aggregate_moment_5(h):
    """moment aggregation with n=5"""
    return _aggregate_moment(h, n=5)


def scale_identity(h):
    """identity scaling (no scaling operation)"""
    return h


def scale_amplification(h, D, delta):
    """amplification scaling"""
    return h * (np.log(D + 1) / delta)


def scale_attenuation(h, D, delta):
    """attenuation scaling"""
    return h * (delta / np.log(D + 1))


AGGREGATORS = {
    "mean": aggregate_mean,
    "sum": aggregate_sum,
    "max": aggregate_max,
    "min": aggregate_min,
    "std": aggregate_std,
    "var": aggregate_var,
    "moment3": aggregate_moment_3,
    "moment4": aggregate_moment_4,
    "moment5": aggregate_moment_5,
}
SCALERS = {
    "identity": scale_identity,
    "amplification": scale_amplification,
    "attenuation": scale_attenuation,
}


class PNAConvTower(nn.Module):
    """A single PNA tower in PNA layers"""

    def __init__(
        self,
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        dropout=0.0,
        edge_feat_size=0,
    ):
        super(PNAConvTower, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
@@ -86,50 +116,63 @@ class PNAConvTower(nn.Module):
        self.edge_feat_size = edge_feat_size

        self.M = nn.Linear(2 * in_size + edge_feat_size, in_size)
        self.U = nn.Linear(
            (len(aggregators) * len(scalers) + 1) * in_size, out_size
        )
        self.dropout = nn.Dropout(dropout)
        self.batchnorm = nn.BatchNorm1d(out_size)

    def reduce_func(self, nodes):
        """reduce function for PNA layer:
        tensordot of multiple aggregation and scaling operations"""
        msg = nodes.mailbox["msg"]
        degree = msg.size(1)
        h = torch.cat(
            [AGGREGATORS[agg](msg) for agg in self.aggregators], dim=1
        )
        h = torch.cat(
            [
                SCALERS[scaler](h, D=degree, delta=self.delta)
                if scaler != "identity"
                else h
                for scaler in self.scalers
            ],
            dim=1,
        )
        return {"h_neigh": h}

    def message(self, edges):
        """message function for PNA layer"""
        if self.edge_feat_size > 0:
            f = torch.cat(
                [edges.src["h"], edges.dst["h"], edges.data["a"]], dim=-1
            )
        else:
            f = torch.cat([edges.src["h"], edges.dst["h"]], dim=-1)
        return {"msg": self.M(f)}

    def forward(self, graph, node_feat, edge_feat=None):
        """compute the forward pass of a single tower in PNA convolution layer"""
        # calculate graph normalization factors
        snorm_n = torch.cat(
            [
                torch.ones(N, 1).to(node_feat) / N
                for N in graph.batch_num_nodes()
            ],
            dim=0,
        ).sqrt()
        with graph.local_scope():
            graph.ndata["h"] = node_feat
            if self.edge_feat_size > 0:
                assert edge_feat is not None, "Edge features must be provided."
                graph.edata["a"] = edge_feat

            graph.update_all(self.message, self.reduce_func)
            h = self.U(torch.cat([node_feat, graph.ndata["h_neigh"]], dim=-1))
            h = h * snorm_n
            return self.dropout(self.batchnorm(h))


class PNAConv(nn.Module):
    r"""Principal Neighbourhood Aggregation Layer from `Principal Neighbourhood Aggregation
    for Graph Nets <https://arxiv.org/abs/2004.05718>`__
@@ -210,14 +253,29 @@ class PNAConv(nn.Module):
    >>> conv = PNAConv(10, 10, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5)
    >>> ret = conv(g, feat)
    """

    def __init__(
        self,
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        dropout=0.0,
        num_towers=1,
        edge_feat_size=0,
        residual=True,
    ):
        super(PNAConv, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        assert (
            in_size % num_towers == 0
        ), "in_size must be divisible by num_towers"
        assert (
            out_size % num_towers == 0
        ), "out_size must be divisible by num_towers"
        self.tower_in_size = in_size // num_towers
        self.tower_out_size = out_size // num_towers
        self.edge_feat_size = edge_feat_size
@@ -225,17 +283,23 @@ class PNAConv(nn.Module):
        if self.in_size != self.out_size:
            self.residual = False

        self.towers = nn.ModuleList(
            [
                PNAConvTower(
                    self.tower_in_size,
                    self.tower_out_size,
                    aggregators,
                    scalers,
                    delta,
                    dropout=dropout,
                    edge_feat_size=edge_feat_size,
                )
                for _ in range(num_towers)
            ]
        )

        self.mixing_layer = nn.Sequential(
            nn.Linear(out_size, out_size), nn.LeakyReLU()
        )

    def forward(self, graph, node_feat, edge_feat=None):
@@ -261,14 +325,20 @@ class PNAConv(nn.Module):
        The output node feature of shape :math:`(N, h_n')` where :math:`h_n'`
        should be the same as out_size.
        """
        h_cat = torch.cat(
            [
                tower(
                    graph,
                    node_feat[
                        :,
                        ti * self.tower_in_size : (ti + 1) * self.tower_in_size,
                    ],
                    edge_feat,
                )
                for ti, tower in enumerate(self.towers)
            ],
            dim=1,
        )
        h_out = self.mixing_layer(h_cat)
        # add residual connection
        if self.residual:
...@@ -6,6 +6,7 @@ from torch import nn ...@@ -6,6 +6,7 @@ from torch import nn
from .... import function as fn from .... import function as fn
from ..linear import TypedLinear from ..linear import TypedLinear
class RelGraphConv(nn.Module): class RelGraphConv(nn.Module):
r"""Relational graph convolution layer from `Modeling Relational Data with Graph r"""Relational graph convolution layer from `Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__ Convolutional Networks <https://arxiv.org/abs/1703.06103>`__
...@@ -94,21 +95,26 @@ class RelGraphConv(nn.Module): ...@@ -94,21 +95,26 @@ class RelGraphConv(nn.Module):
[-0.4323, -0.1440], [-0.4323, -0.1440],
[-0.1309, -1.0000]], grad_fn=<AddBackward0>) [-0.1309, -1.0000]], grad_fn=<AddBackward0>)
""" """
def __init__(self,
in_feat, def __init__(
out_feat, self,
num_rels, in_feat,
regularizer=None, out_feat,
num_bases=None, num_rels,
bias=True, regularizer=None,
activation=None, num_bases=None,
self_loop=True, bias=True,
dropout=0.0, activation=None,
layer_norm=False): self_loop=True,
dropout=0.0,
layer_norm=False,
):
super().__init__() super().__init__()
if regularizer is not None and num_bases is None: if regularizer is not None and num_bases is None:
num_bases = num_rels num_bases = num_rels
self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases) self.linear_r = TypedLinear(
in_feat, out_feat, num_rels, regularizer, num_bases
)
self.bias = bias self.bias = bias
self.activation = activation self.activation = activation
self.self_loop = self_loop self.self_loop = self_loop
...@@ -123,21 +129,25 @@ class RelGraphConv(nn.Module): ...@@ -123,21 +129,25 @@ class RelGraphConv(nn.Module):
# the module only about graph convolution. # the module only about graph convolution.
# layer norm # layer norm
if self.layer_norm: if self.layer_norm:
self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True) self.layer_norm_weight = nn.LayerNorm(
out_feat, elementwise_affine=True
)
# weight for self loop # weight for self loop
if self.self_loop: if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat)) self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu')) nn.init.xavier_uniform_(
self.loop_weight, gain=nn.init.calculate_gain("relu")
)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def message(self, edges): def message(self, edges):
"""Message function.""" """Message function."""
m = self.linear_r(edges.src['h'], edges.data['etype'], self.presorted) m = self.linear_r(edges.src["h"], edges.data["etype"], self.presorted)
if 'norm' in edges.data: if "norm" in edges.data:
m = m * edges.data['norm'] m = m * edges.data["norm"]
return {'m' : m} return {"m": m}
def forward(self, g, feat, etypes, norm=None, *, presorted=False): def forward(self, g, feat, etypes, norm=None, *, presorted=False):
"""Forward computation. """Forward computation.
...@@ -165,20 +175,20 @@ class RelGraphConv(nn.Module): ...@@ -165,20 +175,20 @@ class RelGraphConv(nn.Module):
""" """
self.presorted = presorted self.presorted = presorted
with g.local_scope(): with g.local_scope():
g.srcdata['h'] = feat g.srcdata["h"] = feat
if norm is not None: if norm is not None:
g.edata['norm'] = norm g.edata["norm"] = norm
g.edata['etype'] = etypes g.edata["etype"] = etypes
# message passing # message passing
g.update_all(self.message, fn.sum('m', 'h')) g.update_all(self.message, fn.sum("m", "h"))
# apply bias and activation # apply bias and activation
h = g.dstdata['h'] h = g.dstdata["h"]
if self.layer_norm: if self.layer_norm:
h = self.layer_norm_weight(h) h = self.layer_norm_weight(h)
if self.bias: if self.bias:
h = h + self.h_bias h = h + self.h_bias
if self.self_loop: if self.self_loop:
h = h + feat[:g.num_dst_nodes()] @ self.loop_weight h = h + feat[: g.num_dst_nodes()] @ self.loop_weight
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
h = self.dropout(h) h = self.dropout(h)
......
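A minimal usage sketch for the reformatted RelGraphConv above (a sketch only, assuming dgl and torch are installed and the layer is importable as dgl.nn.RelGraphConv; graph and feature sizes are illustrative):

>>> import dgl
>>> import torch as th
>>> from dgl.nn import RelGraphConv  # assumed public import path
>>> # 4 nodes, 3 edges, 2 relation types (illustrative sizes)
>>> g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> feat = th.randn(4, 10)
>>> etypes = th.tensor([0, 1, 0])
>>> conv = RelGraphConv(10, 2, num_rels=2)
>>> conv(g, feat, etypes).shape
torch.Size([4, 2])

Passing presorted=True is only valid when the edges are already sorted by etypes; the underlying TypedLinear then takes its segment_mm path instead of gather_mm.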
...@@ -82,14 +82,16 @@ class SGConv(nn.Module): ...@@ -82,14 +82,16 @@ class SGConv(nn.Module):
[-1.9441, -0.9343]], grad_fn=<AddmmBackward>) [-1.9441, -0.9343]], grad_fn=<AddmmBackward>)
""" """
def __init__(self, def __init__(
in_feats, self,
out_feats, in_feats,
k=1, out_feats,
cached=False, k=1,
bias=True, cached=False,
norm=None, bias=True,
allow_zero_in_degree=False): norm=None,
allow_zero_in_degree=False,
):
super(SGConv, self).__init__() super(SGConv, self).__init__()
self.fc = nn.Linear(in_feats, out_feats, bias=bias) self.fc = nn.Linear(in_feats, out_feats, bias=bias)
self._cached = cached self._cached = cached
...@@ -170,20 +172,23 @@ class SGConv(nn.Module): ...@@ -170,20 +172,23 @@ class SGConv(nn.Module):
with graph.local_scope(): with graph.local_scope():
if not self._allow_zero_in_degree: if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any(): if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, ' raise DGLError(
'output for those nodes will be invalid. ' "There are 0-in-degree nodes in the graph, "
'This is harmful for some applications, ' "output for those nodes will be invalid. "
'causing silent performance regression. ' "This is harmful for some applications, "
'Adding self-loop on the input graph by ' "causing silent performance regression. "
'calling `g = dgl.add_self_loop(g)` will resolve ' "Adding self-loop on the input graph by "
'the issue. Setting ``allow_zero_in_degree`` ' "calling `g = dgl.add_self_loop(g)` will resolve "
'to be `True` when constructing this module will ' "the issue. Setting ``allow_zero_in_degree`` "
'suppress the check and let the code run.') "to be `True` when constructing this module will "
"suppress the check and let the code run."
)
msg_func = fn.copy_u("h", "m") msg_func = fn.copy_u("h", "m")
if edge_weight is not None: if edge_weight is not None:
graph.edata["_edge_weight"] = EdgeWeightNorm( graph.edata["_edge_weight"] = EdgeWeightNorm("both")(
'both')(graph, edge_weight) graph, edge_weight
)
msg_func = fn.u_mul_e("h", "_edge_weight", "m") msg_func = fn.u_mul_e("h", "_edge_weight", "m")
if self._cached_h is not None: if self._cached_h is not None:
...@@ -198,10 +203,9 @@ class SGConv(nn.Module): ...@@ -198,10 +203,9 @@ class SGConv(nn.Module):
for _ in range(self._k): for _ in range(self._k):
if edge_weight is None: if edge_weight is None:
feat = feat * norm feat = feat * norm
graph.ndata['h'] = feat graph.ndata["h"] = feat
graph.update_all(msg_func, graph.update_all(msg_func, fn.sum("m", "h"))
fn.sum('m', 'h')) feat = graph.ndata.pop("h")
feat = graph.ndata.pop('h')
if edge_weight is None: if edge_weight is None:
feat = feat * norm feat = feat * norm
......
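A minimal sketch for SGConv (assuming the layer is importable as dgl.nn.SGConv; sizes are illustrative). Adding self-loops first sidesteps the 0-in-degree DGLError raised in the hunk above:

>>> import dgl
>>> import torch as th
>>> from dgl.nn import SGConv  # assumed public import path
>>> g = dgl.add_self_loop(dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0]))))
>>> feat = th.randn(3, 5)
>>> conv = SGConv(5, 2, k=2, cached=True)
>>> conv(g, feat).shape
torch.Size([3, 2])

With cached=True the propagated features are stored after the first call, so the same module instance should only be reused on the same graph.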
...@@ -57,13 +57,14 @@ class TAGConv(nn.Module): ...@@ -57,13 +57,14 @@ class TAGConv(nn.Module):
[ 0.3304, -1.9927]], grad_fn=<AddmmBackward>) [ 0.3304, -1.9927]], grad_fn=<AddmmBackward>)
""" """
def __init__(self, def __init__(
in_feats, self,
out_feats, in_feats,
k=2, out_feats,
bias=True, k=2,
activation=None, bias=True,
): activation=None,
):
super(TAGConv, self).__init__() super(TAGConv, self).__init__()
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
...@@ -84,7 +85,7 @@ class TAGConv(nn.Module): ...@@ -84,7 +85,7 @@ class TAGConv(nn.Module):
---- ----
The model parameters are initialized using Glorot uniform initialization. The model parameters are initialized using Glorot uniform initialization.
""" """
gain = nn.init.calculate_gain('relu') gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.lin.weight, gain=gain) nn.init.xavier_normal_(self.lin.weight, gain=gain)
def forward(self, graph, feat, edge_weight=None): def forward(self, graph, feat, edge_weight=None):
...@@ -114,7 +115,7 @@ class TAGConv(nn.Module): ...@@ -114,7 +115,7 @@ class TAGConv(nn.Module):
is size of output feature. is size of output feature.
""" """
with graph.local_scope(): with graph.local_scope():
assert graph.is_homogeneous, 'Graph is not homogeneous' assert graph.is_homogeneous, "Graph is not homogeneous"
if edge_weight is None: if edge_weight is None:
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5) norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1) shp = norm.shape + (1,) * (feat.dim() - 1)
...@@ -122,8 +123,9 @@ class TAGConv(nn.Module): ...@@ -122,8 +123,9 @@ class TAGConv(nn.Module):
msg_func = fn.copy_u("h", "m") msg_func = fn.copy_u("h", "m")
if edge_weight is not None: if edge_weight is not None:
graph.edata["_edge_weight"] = EdgeWeightNorm( graph.edata["_edge_weight"] = EdgeWeightNorm("both")(
'both')(graph, edge_weight) graph, edge_weight
)
msg_func = fn.u_mul_e("h", "_edge_weight", "m") msg_func = fn.u_mul_e("h", "_edge_weight", "m")
# D-1/2 A D -1/2 X # D-1/2 A D -1/2 X
fstack = [feat] fstack = [feat]
...@@ -132,11 +134,10 @@ class TAGConv(nn.Module): ...@@ -132,11 +134,10 @@ class TAGConv(nn.Module):
rst = fstack[-1] * norm rst = fstack[-1] * norm
else: else:
rst = fstack[-1] rst = fstack[-1]
graph.ndata['h'] = rst graph.ndata["h"] = rst
graph.update_all(msg_func, graph.update_all(msg_func, fn.sum(msg="m", out="h"))
fn.sum(msg='m', out='h')) rst = graph.ndata["h"]
rst = graph.ndata['h']
if edge_weight is None: if edge_weight is None:
rst = rst * norm rst = rst * norm
fstack.append(rst) fstack.append(rst)
......
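A corresponding sketch for TAGConv, including the optional edge_weight path that the hunk above normalizes with EdgeWeightNorm("both") (assuming the layer is importable as dgl.nn.TAGConv; sizes are illustrative):

>>> import dgl
>>> import torch as th
>>> from dgl.nn import TAGConv  # assumed public import path
>>> g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))
>>> feat = th.randn(3, 5)
>>> conv = TAGConv(5, 2, k=2)
>>> conv(g, feat).shape
torch.Size([3, 2])
>>> conv(g, feat, edge_weight=th.ones(g.num_edges())).shape
torch.Size([3, 2])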
...@@ -4,8 +4,10 @@ ...@@ -4,8 +4,10 @@
import torch as tc import torch as tc
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .... import function as fn from .... import function as fn
class TWIRLSConv(nn.Module): class TWIRLSConv(nn.Module):
r"""Convolution together with iteratively reweighting least squre from r"""Convolution together with iteratively reweighting least squre from
`Graph Neural Networks Inspired by Classical Iterative Algorithms `Graph Neural Networks Inspired by Classical Iterative Algorithms
...@@ -74,27 +76,28 @@ class TWIRLSConv(nn.Module): ...@@ -74,27 +76,28 @@ class TWIRLSConv(nn.Module):
torch.Size([6, 2]) torch.Size([6, 2])
""" """
def __init__(self, def __init__(
input_d, self,
output_d, input_d,
hidden_d, output_d,
prop_step, hidden_d,
num_mlp_before=1, prop_step,
num_mlp_after=1, num_mlp_before=1,
norm='none', num_mlp_after=1,
precond=True, norm="none",
alp=0, precond=True,
lam=1, alp=0,
attention=False, lam=1,
tau=0.2, attention=False,
T=-1, tau=0.2,
p=1, T=-1,
use_eta=False, p=1,
attn_bef=False, use_eta=False,
dropout=0.0, attn_bef=False,
attn_dropout=0.0, dropout=0.0,
inp_dropout=0.0, attn_dropout=0.0,
): inp_dropout=0.0,
):
super().__init__() super().__init__()
self.input_d = input_d self.input_d = input_d
...@@ -123,7 +126,10 @@ class TWIRLSConv(nn.Module): ...@@ -123,7 +126,10 @@ class TWIRLSConv(nn.Module):
# whether we can cache unfolding result # whether we can cache unfolding result
self.cacheable = ( self.cacheable = (
not self.attention) and self.num_mlp_before == 0 and self.inp_dropout <= 0 (not self.attention)
and self.num_mlp_before == 0
and self.inp_dropout <= 0
)
if self.cacheable: if self.cacheable:
self.cached_unfolding = None self.cached_unfolding = None
...@@ -136,20 +142,42 @@ class TWIRLSConv(nn.Module): ...@@ -136,20 +142,42 @@ class TWIRLSConv(nn.Module):
self.size_bef_unf = self.output_d # as the output of mlp_bef self.size_bef_unf = self.output_d # as the output of mlp_bef
# ----- computational modules ----- # ----- computational modules -----
self.mlp_bef = MLP(self.input_d, self.hidden_d, self.size_bef_unf, self.num_mlp_before, self.mlp_bef = MLP(
self.dropout, self.norm, init_activate=False) self.input_d,
self.hidden_d,
self.unfolding = TWIRLSUnfoldingAndAttention(self.hidden_d, self.alp, self.lam, self.size_bef_unf,
self.prop_step, self.attn_aft, self.tau, self.num_mlp_before,
self.T, self.p, self.use_eta, self.init_att, self.dropout,
self.attn_dropout, self.precond) self.norm,
init_activate=False,
)
self.unfolding = TWIRLSUnfoldingAndAttention(
self.hidden_d,
self.alp,
self.lam,
self.prop_step,
self.attn_aft,
self.tau,
self.T,
self.p,
self.use_eta,
self.init_att,
self.attn_dropout,
self.precond,
)
# if there are really transformations before unfolding, then do init_activate in mlp_aft # if there are really transformations before unfolding, then do init_activate in mlp_aft
self.mlp_aft = MLP(self.size_aft_unf, self.hidden_d, self.output_d, self.num_mlp_after, self.mlp_aft = MLP(
self.dropout, self.norm, self.size_aft_unf,
init_activate=(self.num_mlp_before > 0) and ( self.hidden_d,
self.num_mlp_after > 0) self.output_d,
) self.num_mlp_after,
self.dropout,
self.norm,
init_activate=(self.num_mlp_before > 0)
and (self.num_mlp_after > 0),
)
def forward(self, graph, feat): def forward(self, graph, feat):
r""" r"""
...@@ -212,8 +240,7 @@ class Propagate(nn.Module): ...@@ -212,8 +240,7 @@ class Propagate(nn.Module):
super().__init__() super().__init__()
def _prop(self, graph, Y, lam): def _prop(self, graph, Y, lam):
"""propagation part. """propagation part."""
"""
Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam) Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam)
Y = AX(graph, Y) Y = AX(graph, Y)
Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam) Y = D_power_bias_X(graph, Y, -0.5, lam, 1 - lam)
...@@ -245,8 +272,11 @@ class Propagate(nn.Module): ...@@ -245,8 +272,11 @@ class Propagate(nn.Module):
Propagated feature. :math:`Z^{(k+1)}` in eq.28 in the paper. Propagated feature. :math:`Z^{(k+1)}` in eq.28 in the paper.
""" """
return (1 - alp) * Y + alp * lam * self._prop(graph, Y, lam) \ return (
(1 - alp) * Y
+ alp * lam * self._prop(graph, Y, lam)
+ alp * D_power_bias_X(graph, X, -1, lam, 1 - lam) + alp * D_power_bias_X(graph, X, -1, lam, 1 - lam)
)
class PropagateNoPrecond(nn.Module): class PropagateNoPrecond(nn.Module):
...@@ -287,7 +317,11 @@ class PropagateNoPrecond(nn.Module): ...@@ -287,7 +317,11 @@ class PropagateNoPrecond(nn.Module):
Propagated feature. :math:`Y^{(k+1)}` in eq.30 in the paper. Propagated feature. :math:`Y^{(k+1)}` in eq.30 in the paper.
""" """
return (1 - alp * lam - alp) * Y + alp * lam * normalized_AX(graph, Y) + alp * X return (
(1 - alp * lam - alp) * Y
+ alp * lam * normalized_AX(graph, Y)
+ alp * X
)
class Attention(nn.Module): class Attention(nn.Module):
...@@ -372,7 +406,7 @@ class Attention(nn.Module): ...@@ -372,7 +406,7 @@ class Attention(nn.Module):
# computing edge distance # computing edge distance
graph.srcdata["h"] = Y graph.srcdata["h"] = Y
graph.srcdata["h_norm"] = (Y ** 2).sum(-1) graph.srcdata["h_norm"] = (Y**2).sum(-1)
graph.apply_edges(fn.u_dot_v("h", "h", "dot_")) graph.apply_edges(fn.u_dot_v("h", "h", "dot_"))
graph.apply_edges(fn.u_add_v("h_norm", "h_norm", "norm_")) graph.apply_edges(fn.u_add_v("h_norm", "h_norm", "norm_"))
graph.edata["dot_"] = graph.edata["dot_"].view(-1) graph.edata["dot_"] = graph.edata["dot_"].view(-1)
...@@ -390,7 +424,8 @@ class Attention(nn.Module): ...@@ -390,7 +424,8 @@ class Attention(nn.Module):
# FIXME: consider if there is a better way # FIXME: consider if there is a better way
if self.attn_dropout > 0: if self.attn_dropout > 0:
graph.edata["w"] = F.dropout( graph.edata["w"] = F.dropout(
graph.edata["w"], self.attn_dropout, training=self.training) graph.edata["w"], self.attn_dropout, training=self.training
)
return graph return graph
...@@ -410,7 +445,8 @@ def AX(graph, X): ...@@ -410,7 +445,8 @@ def AX(graph, X):
graph.srcdata["h"] = X graph.srcdata["h"] = X
graph.update_all( graph.update_all(
fn.u_mul_e("h", "w", "m"), fn.sum("m", "h"), fn.u_mul_e("h", "w", "m"),
fn.sum("m", "h"),
) )
Y = graph.dstdata["h"] Y = graph.dstdata["h"]
...@@ -491,20 +527,21 @@ class TWIRLSUnfoldingAndAttention(nn.Module): ...@@ -491,20 +527,21 @@ class TWIRLSUnfoldingAndAttention(nn.Module):
""" """
def __init__(self, def __init__(
d, self,
alp, d,
lam, alp,
prop_step, lam,
attn_aft=-1, prop_step,
tau=0.2, attn_aft=-1,
T=-1, tau=0.2,
p=1, T=-1,
use_eta=False, p=1,
init_att=False, use_eta=False,
attn_dropout=0, init_att=False,
precond=True, attn_dropout=0,
): precond=True,
):
super().__init__() super().__init__()
...@@ -520,12 +557,15 @@ class TWIRLSUnfoldingAndAttention(nn.Module): ...@@ -520,12 +557,15 @@ class TWIRLSUnfoldingAndAttention(nn.Module):
prop_method = Propagate if precond else PropagateNoPrecond prop_method = Propagate if precond else PropagateNoPrecond
self.prop_layers = nn.ModuleList( self.prop_layers = nn.ModuleList(
[prop_method() for _ in range(prop_step)]) [prop_method() for _ in range(prop_step)]
)
self.init_attn = Attention(
tau, T, p, attn_dropout) if self.init_att else None self.init_attn = (
self.attn_layer = Attention( Attention(tau, T, p, attn_dropout) if self.init_att else None
tau, T, p, attn_dropout) if self.attn_aft >= 0 else None )
self.attn_layer = (
Attention(tau, T, p, attn_dropout) if self.attn_aft >= 0 else None
)
self.etas = nn.Parameter(tc.ones(d)) if self.use_eta else None self.etas = nn.Parameter(tc.ones(d)) if self.use_eta else None
def forward(self, g, X): def forward(self, g, X):
...@@ -593,7 +633,16 @@ class MLP(nn.Module): ...@@ -593,7 +633,16 @@ class MLP(nn.Module):
""" """
def __init__(self, input_d, hidden_d, output_d, num_layers, dropout, norm, init_activate): def __init__(
self,
input_d,
hidden_d,
output_d,
num_layers,
dropout,
norm,
init_activate,
):
super().__init__() super().__init__()
self.init_activate = init_activate self.init_activate = init_activate
...@@ -611,13 +660,15 @@ class MLP(nn.Module): ...@@ -611,13 +660,15 @@ class MLP(nn.Module):
self.layers.append(nn.Linear(hidden_d, output_d)) self.layers.append(nn.Linear(hidden_d, output_d))
# how many norm layers we have # how many norm layers we have
self.norm_cnt = num_layers-1+int(init_activate) self.norm_cnt = num_layers - 1 + int(init_activate)
if norm == "batch": if norm == "batch":
self.norms = nn.ModuleList( self.norms = nn.ModuleList(
[nn.BatchNorm1d(hidden_d) for _ in range(self.norm_cnt)]) [nn.BatchNorm1d(hidden_d) for _ in range(self.norm_cnt)]
)
elif norm == "layer": elif norm == "layer":
self.norms = nn.ModuleList( self.norms = nn.ModuleList(
[nn.LayerNorm(hidden_d) for _ in range(self.norm_cnt)]) [nn.LayerNorm(hidden_d) for _ in range(self.norm_cnt)]
)
self.reset_params() self.reset_params()
......
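A usage sketch for TWIRLSConv that reproduces the docstring shape shown above (a sketch only, assuming the layer is exposed as dgl.nn.TWIRLSConv; the cycle graph keeps every in- and out-degree positive):

>>> import dgl
>>> import torch as th
>>> from dgl.nn import TWIRLSConv  # assumed public import path
>>> # directed 6-node cycle, so no node has zero degree
>>> g = dgl.graph((th.tensor([0, 1, 2, 3, 4, 5]), th.tensor([1, 2, 3, 4, 5, 0])))
>>> feat = th.ones(6, 10)
>>> conv = TWIRLSConv(10, 2, 128, prop_step=64)
>>> conv(g, feat).shape
torch.Size([6, 2])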
...@@ -3,4 +3,4 @@ ...@@ -3,4 +3,4 @@
from .gnnexplainer import GNNExplainer from .gnnexplainer import GNNExplainer
__all__ = ['GNNExplainer'] __all__ = ["GNNExplainer"]
"""Modules that transforms between graphs and between graph and tensors.""" """Modules that transforms between graphs and between graph and tensors."""
import torch.nn as nn import torch.nn as nn
from ...transforms import knn_graph, segmented_knn_graph, radius_graph
from ...transforms import knn_graph, radius_graph, segmented_knn_graph
def pairwise_squared_distance(x): def pairwise_squared_distance(x):
''' """
x : (n_samples, n_points, dims) x : (n_samples, n_points, dims)
return : (n_samples, n_points, n_points) return : (n_samples, n_points, n_points)
''' """
x2s = (x * x).sum(-1, keepdim=True) x2s = (x * x).sum(-1, keepdim=True)
return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2) return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)
...@@ -58,13 +60,19 @@ class KNNGraph(nn.Module): ...@@ -58,13 +60,19 @@ class KNNGraph(nn.Module):
(tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]), (tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]),
tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5])) tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
""" """
def __init__(self, k): def __init__(self, k):
super(KNNGraph, self).__init__() super(KNNGraph, self).__init__()
self.k = k self.k = k
#pylint: disable=invalid-name # pylint: disable=invalid-name
def forward(self, x, algorithm='bruteforce-blas', dist='euclidean', def forward(
exclude_self=False): self,
x,
algorithm="bruteforce-blas",
dist="euclidean",
exclude_self=False,
):
r""" r"""
Forward computation. Forward computation.
...@@ -124,8 +132,9 @@ class KNNGraph(nn.Module): ...@@ -124,8 +132,9 @@ class KNNGraph(nn.Module):
DGLGraph DGLGraph
A DGLGraph without features. A DGLGraph without features.
""" """
return knn_graph(x, self.k, algorithm=algorithm, dist=dist, return knn_graph(
exclude_self=exclude_self) x, self.k, algorithm=algorithm, dist=dist, exclude_self=exclude_self
)
class SegmentedKNNGraph(nn.Module): class SegmentedKNNGraph(nn.Module):
...@@ -172,13 +181,20 @@ class SegmentedKNNGraph(nn.Module): ...@@ -172,13 +181,20 @@ class SegmentedKNNGraph(nn.Module):
>>> >>>
""" """
def __init__(self, k): def __init__(self, k):
super(SegmentedKNNGraph, self).__init__() super(SegmentedKNNGraph, self).__init__()
self.k = k self.k = k
#pylint: disable=invalid-name # pylint: disable=invalid-name
def forward(self, x, segs, algorithm='bruteforce-blas', dist='euclidean', def forward(
exclude_self=False): self,
x,
segs,
algorithm="bruteforce-blas",
dist="euclidean",
exclude_self=False,
):
r"""Forward computation. r"""Forward computation.
Parameters Parameters
...@@ -240,8 +256,14 @@ class SegmentedKNNGraph(nn.Module): ...@@ -240,8 +256,14 @@ class SegmentedKNNGraph(nn.Module):
A batched DGLGraph without features. A batched DGLGraph without features.
""" """
return segmented_knn_graph(x, self.k, segs, algorithm=algorithm, dist=dist, return segmented_knn_graph(
exclude_self=exclude_self) x,
self.k,
segs,
algorithm=algorithm,
dist=dist,
exclude_self=exclude_self,
)
class RadiusGraph(nn.Module): class RadiusGraph(nn.Module):
...@@ -316,16 +338,21 @@ class RadiusGraph(nn.Module): ...@@ -316,16 +338,21 @@ class RadiusGraph(nn.Module):
[0.7000], [0.7000],
[0.2828]]) [0.2828]])
""" """
#pylint: disable=invalid-name # pylint: disable=invalid-name
def __init__(self, r, p=2, self_loop=False, def __init__(
compute_mode='donot_use_mm_for_euclid_dist'): self,
r,
p=2,
self_loop=False,
compute_mode="donot_use_mm_for_euclid_dist",
):
super(RadiusGraph, self).__init__() super(RadiusGraph, self).__init__()
self.r = r self.r = r
self.p = p self.p = p
self.self_loop = self_loop self.self_loop = self_loop
self.compute_mode = compute_mode self.compute_mode = compute_mode
#pylint: disable=invalid-name # pylint: disable=invalid-name
def forward(self, x, get_distances=False): def forward(self, x, get_distances=False):
r""" r"""
Forward computation. Forward computation.
...@@ -351,5 +378,6 @@ class RadiusGraph(nn.Module): ...@@ -351,5 +378,6 @@ class RadiusGraph(nn.Module):
The distances for the edges in the constructed graph. The distances The distances for the edges in the constructed graph. The distances
are in the same order as edge IDs. are in the same order as edge IDs.
""" """
return radius_graph(x, self.r, self.p, self.self_loop, return radius_graph(
self.compute_mode, get_distances) x, self.r, self.p, self.self_loop, self.compute_mode, get_distances
)
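A short sketch for the graph-construction modules above (assuming both are importable from dgl.nn; point coordinates are random). KNNGraph links every point to its k nearest neighbours, itself included unless exclude_self=True, so a graph over 8 points with k=3 has 24 edges:

>>> import torch as th
>>> from dgl.nn import KNNGraph, RadiusGraph  # assumed public import paths
>>> x = th.randn(8, 3)              # 8 points in 3-D
>>> g = KNNGraph(3)(x)              # 3 incoming edges per point
>>> g.num_nodes(), g.num_edges()
(8, 24)
>>> rg = RadiusGraph(0.8)           # connect pairs closer than 0.8
>>> g2, dists = rg(th.rand(8, 3), get_distances=True)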
"""Torch modules for graph global pooling.""" """Torch modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235 # pylint: disable= no-member, arguments-differ, invalid-name, W0235
import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import numpy as np
from ...backend import pytorch as F from ...backend import pytorch as F
from ...base import dgl_warning from ...base import dgl_warning
from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\ from ...readout import (
softmax_nodes, topk_nodes broadcast_nodes,
max_nodes,
mean_nodes,
softmax_nodes,
sum_nodes,
topk_nodes,
)
__all__ = [
"SumPooling",
"AvgPooling",
"MaxPooling",
"SortPooling",
"GlobalAttentionPooling",
"Set2Set",
"SetTransformerEncoder",
"SetTransformerDecoder",
"WeightAndSum",
]
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
'GlobalAttentionPooling', 'Set2Set',
'SetTransformerEncoder', 'SetTransformerDecoder', 'WeightAndSum']
class SumPooling(nn.Module): class SumPooling(nn.Module):
r"""Apply sum pooling over the nodes in a graph. r"""Apply sum pooling over the nodes in a graph.
...@@ -67,6 +81,7 @@ class SumPooling(nn.Module): ...@@ -67,6 +81,7 @@ class SumPooling(nn.Module):
tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511], tensor([[2.2282, 1.8667, 2.4338, 1.7540, 1.4511],
[1.0608, 1.2080, 2.1780, 2.7849, 2.5420]]) [1.0608, 1.2080, 2.1780, 2.7849, 2.5420]])
""" """
def __init__(self): def __init__(self):
super(SumPooling, self).__init__() super(SumPooling, self).__init__()
...@@ -90,8 +105,8 @@ class SumPooling(nn.Module): ...@@ -90,8 +105,8 @@ class SumPooling(nn.Module):
batch size of input graphs. batch size of input graphs.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata["h"] = feat
readout = sum_nodes(graph, 'h') readout = sum_nodes(graph, "h")
return readout return readout
...@@ -148,6 +163,7 @@ class AvgPooling(nn.Module): ...@@ -148,6 +163,7 @@ class AvgPooling(nn.Module):
tensor([[0.7427, 0.6222, 0.8113, 0.5847, 0.4837], tensor([[0.7427, 0.6222, 0.8113, 0.5847, 0.4837],
[0.2652, 0.3020, 0.5445, 0.6962, 0.6355]]) [0.2652, 0.3020, 0.5445, 0.6962, 0.6355]])
""" """
def __init__(self): def __init__(self):
super(AvgPooling, self).__init__() super(AvgPooling, self).__init__()
...@@ -171,8 +187,8 @@ class AvgPooling(nn.Module): ...@@ -171,8 +187,8 @@ class AvgPooling(nn.Module):
:math:`B` refers to the batch size of input graphs. :math:`B` refers to the batch size of input graphs.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata["h"] = feat
readout = mean_nodes(graph, 'h') readout = mean_nodes(graph, "h")
return readout return readout
...@@ -229,6 +245,7 @@ class MaxPooling(nn.Module): ...@@ -229,6 +245,7 @@ class MaxPooling(nn.Module):
tensor([[0.8948, 0.9030, 0.9137, 0.7567, 0.6118], tensor([[0.8948, 0.9030, 0.9137, 0.7567, 0.6118],
[0.5278, 0.6365, 0.9990, 0.9028, 0.8945]]) [0.5278, 0.6365, 0.9990, 0.9028, 0.8945]])
""" """
def __init__(self): def __init__(self):
super(MaxPooling, self).__init__() super(MaxPooling, self).__init__()
...@@ -250,8 +267,8 @@ class MaxPooling(nn.Module): ...@@ -250,8 +267,8 @@ class MaxPooling(nn.Module):
:math:`B` refers to the batch size. :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata["h"] = feat
readout = max_nodes(graph, 'h') readout = max_nodes(graph, "h")
return readout return readout
...@@ -316,6 +333,7 @@ class SortPooling(nn.Module): ...@@ -316,6 +333,7 @@ class SortPooling(nn.Module):
[0.2351, 0.5278, 0.6365, 0.8945, 0.9990, 0.2053, 0.2426, 0.4111, 0.5658, [0.2351, 0.5278, 0.6365, 0.8945, 0.9990, 0.2053, 0.2426, 0.4111, 0.5658,
0.9028]]) 0.9028]])
""" """
def __init__(self, k): def __init__(self, k):
super(SortPooling, self).__init__() super(SortPooling, self).__init__()
self.k = k self.k = k
...@@ -342,10 +360,11 @@ class SortPooling(nn.Module): ...@@ -342,10 +360,11 @@ class SortPooling(nn.Module):
with graph.local_scope(): with graph.local_scope():
# Sort the feature of each node in ascending order. # Sort the feature of each node in ascending order.
feat, _ = feat.sort(dim=-1) feat, _ = feat.sort(dim=-1)
graph.ndata['h'] = feat graph.ndata["h"] = feat
# Sort nodes according to their last features. # Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k, sortby=-1)[0].view( ret = topk_nodes(graph, "h", self.k, sortby=-1)[0].view(
-1, self.k * feat.shape[-1]) -1, self.k * feat.shape[-1]
)
return ret return ret
...@@ -414,6 +433,7 @@ class GlobalAttentionPooling(nn.Module): ...@@ -414,6 +433,7 @@ class GlobalAttentionPooling(nn.Module):
on how to use GatedGraphConv and GlobalAttentionPooling layer to build a Graph Neural on how to use GatedGraphConv and GlobalAttentionPooling layer to build a Graph Neural
Networks that can solve Sudoku. Networks that can solve Sudoku.
""" """
def __init__(self, gate_nn, feat_nn=None): def __init__(self, gate_nn, feat_nn=None):
super(GlobalAttentionPooling, self).__init__() super(GlobalAttentionPooling, self).__init__()
self.gate_nn = gate_nn self.gate_nn = gate_nn
...@@ -445,16 +465,18 @@ class GlobalAttentionPooling(nn.Module): ...@@ -445,16 +465,18 @@ class GlobalAttentionPooling(nn.Module):
""" """
with graph.local_scope(): with graph.local_scope():
gate = self.gate_nn(feat) gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis." assert (
gate.shape[-1] == 1
), "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat feat = self.feat_nn(feat) if self.feat_nn else feat
graph.ndata['gate'] = gate graph.ndata["gate"] = gate
gate = softmax_nodes(graph, 'gate') gate = softmax_nodes(graph, "gate")
graph.ndata.pop('gate') graph.ndata.pop("gate")
graph.ndata['r'] = feat * gate graph.ndata["r"] = feat * gate
readout = sum_nodes(graph, 'r') readout = sum_nodes(graph, "r")
graph.ndata.pop('r') graph.ndata.pop("r")
if get_attention: if get_attention:
return readout, gate return readout, gate
...@@ -540,6 +562,7 @@ class Set2Set(nn.Module): ...@@ -540,6 +562,7 @@ class Set2Set(nn.Module):
mpnn_predictor.py>`__ mpnn_predictor.py>`__
on how to use DGL's Set2Set layer in graph property prediction applications. on how to use DGL's Set2Set layer in graph property prediction applications.
""" """
def __init__(self, input_dim, n_iters, n_layers): def __init__(self, input_dim, n_iters, n_layers):
super(Set2Set, self).__init__() super(Set2Set, self).__init__()
self.input_dim = input_dim self.input_dim = input_dim
...@@ -574,8 +597,10 @@ class Set2Set(nn.Module): ...@@ -574,8 +597,10 @@ class Set2Set(nn.Module):
with graph.local_scope(): with graph.local_scope():
batch_size = graph.batch_size batch_size = graph.batch_size
h = (feat.new_zeros((self.n_layers, batch_size, self.input_dim)), h = (
feat.new_zeros((self.n_layers, batch_size, self.input_dim))) feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
feat.new_zeros((self.n_layers, batch_size, self.input_dim)),
)
q_star = feat.new_zeros(batch_size, self.output_dim) q_star = feat.new_zeros(batch_size, self.output_dim)
...@@ -583,10 +608,10 @@ class Set2Set(nn.Module): ...@@ -583,10 +608,10 @@ class Set2Set(nn.Module):
q, h = self.lstm(q_star.unsqueeze(0), h) q, h = self.lstm(q_star.unsqueeze(0), h)
q = q.view(batch_size, self.input_dim) q = q.view(batch_size, self.input_dim)
e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True) e = (feat * broadcast_nodes(graph, q)).sum(dim=-1, keepdim=True)
graph.ndata['e'] = e graph.ndata["e"] = e
alpha = softmax_nodes(graph, 'e') alpha = softmax_nodes(graph, "e")
graph.ndata['r'] = feat * alpha graph.ndata["r"] = feat * alpha
readout = sum_nodes(graph, 'r') readout = sum_nodes(graph, "r")
q_star = th.cat([q, readout], dim=-1) q_star = th.cat([q, readout], dim=-1)
return q_star return q_star
...@@ -595,12 +620,12 @@ class Set2Set(nn.Module): ...@@ -595,12 +620,12 @@ class Set2Set(nn.Module):
"""Set the extra representation of the module. """Set the extra representation of the module.
which will come into effect when printing the model. which will come into effect when printing the model.
""" """
summary = 'n_iters={n_iters}' summary = "n_iters={n_iters}"
return summary.format(**self.__dict__) return summary.format(**self.__dict__)
def _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y): def _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y):
""" Generate binary mask array for given x and y input pairs. """Generate binary mask array for given x and y input pairs.
Parameters Parameters
---------- ----------
...@@ -620,9 +645,13 @@ def _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y): ...@@ -620,9 +645,13 @@ def _gen_mask(lengths_x, lengths_y, max_len_x, max_len_y):
""" """
device = lengths_x.device device = lengths_x.device
# x_mask: (batch_size, max_len_x) # x_mask: (batch_size, max_len_x)
x_mask = th.arange(max_len_x, device=device).unsqueeze(0) < lengths_x.unsqueeze(1) x_mask = th.arange(max_len_x, device=device).unsqueeze(
0
) < lengths_x.unsqueeze(1)
# y_mask: (batch_size, max_len_y) # y_mask: (batch_size, max_len_y)
y_mask = th.arange(max_len_y, device=device).unsqueeze(0) < lengths_y.unsqueeze(1) y_mask = th.arange(max_len_y, device=device).unsqueeze(
0
) < lengths_y.unsqueeze(1)
# mask: (batch_size, 1, max_len_x, max_len_y) # mask: (batch_size, 1, max_len_x, max_len_y)
mask = (x_mask.unsqueeze(-1) & y_mask.unsqueeze(-2)).unsqueeze(1) mask = (x_mask.unsqueeze(-1) & y_mask.unsqueeze(-2)).unsqueeze(1)
return mask return mask
...@@ -650,7 +679,10 @@ class MultiHeadAttention(nn.Module): ...@@ -650,7 +679,10 @@ class MultiHeadAttention(nn.Module):
----- -----
This module was used in SetTransformer layer. This module was used in SetTransformer layer.
""" """
def __init__(self, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
def __init__(
self, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0
):
super(MultiHeadAttention, self).__init__() super(MultiHeadAttention, self).__init__()
self.d_model = d_model self.d_model = d_model
self.num_heads = num_heads self.num_heads = num_heads
...@@ -664,7 +696,7 @@ class MultiHeadAttention(nn.Module): ...@@ -664,7 +696,7 @@ class MultiHeadAttention(nn.Module):
nn.Linear(d_model, d_ff), nn.Linear(d_model, d_ff),
nn.ReLU(), nn.ReLU(),
nn.Dropout(dropouth), nn.Dropout(dropouth),
nn.Linear(d_ff, d_model) nn.Linear(d_ff, d_model),
) )
self.droph = nn.Dropout(dropouth) self.droph = nn.Dropout(dropouth)
self.dropa = nn.Dropout(dropouta) self.dropa = nn.Dropout(dropouta)
...@@ -710,25 +742,28 @@ class MultiHeadAttention(nn.Module): ...@@ -710,25 +742,28 @@ class MultiHeadAttention(nn.Module):
values = F.pad_packed_tensor(values, lengths_mem, 0) values = F.pad_packed_tensor(values, lengths_mem, 0)
# attention score with shape (B, num_heads, max_len_x, max_len_mem) # attention score with shape (B, num_heads, max_len_x, max_len_mem)
e = th.einsum('bxhd,byhd->bhxy', queries, keys) e = th.einsum("bxhd,byhd->bhxy", queries, keys)
# normalize # normalize
e = e / np.sqrt(self.d_head) e = e / np.sqrt(self.d_head)
# generate mask # generate mask
mask = _gen_mask(lengths_x, lengths_mem, max_len_x, max_len_mem) mask = _gen_mask(lengths_x, lengths_mem, max_len_x, max_len_mem)
e = e.masked_fill(mask == 0, -float('inf')) e = e.masked_fill(mask == 0, -float("inf"))
# apply softmax # apply softmax
alpha = th.softmax(e, dim=-1) alpha = th.softmax(e, dim=-1)
# the following line addresses the NaN issue, see # the following line addresses the NaN issue, see
# https://github.com/dmlc/dgl/issues/2657 # https://github.com/dmlc/dgl/issues/2657
alpha = alpha.masked_fill(mask == 0, 0.) alpha = alpha.masked_fill(mask == 0, 0.0)
# sum of value weighted by alpha # sum of value weighted by alpha
out = th.einsum('bhxy,byhd->bxhd', alpha, values) out = th.einsum("bhxy,byhd->bxhd", alpha, values)
# project to output # project to output
out = self.proj_o( out = self.proj_o(
out.contiguous().view(batch_size, max_len_x, self.num_heads * self.d_head)) out.contiguous().view(
batch_size, max_len_x, self.num_heads * self.d_head
)
)
# pack tensor # pack tensor
out = F.pack_padded_tensor(out, lengths_x) out = F.pack_padded_tensor(out, lengths_x)
...@@ -764,10 +799,19 @@ class SetAttentionBlock(nn.Module): ...@@ -764,10 +799,19 @@ class SetAttentionBlock(nn.Module):
----- -----
This module was used in SetTransformer layer. This module was used in SetTransformer layer.
""" """
def __init__(self, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
def __init__(
self, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0
):
super(SetAttentionBlock, self).__init__() super(SetAttentionBlock, self).__init__()
self.mha = MultiHeadAttention(d_model, num_heads, d_head, d_ff, self.mha = MultiHeadAttention(
dropouth=dropouth, dropouta=dropouta) d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
def forward(self, feat, lengths): def forward(self, feat, lengths):
""" """
...@@ -808,19 +852,32 @@ class InducedSetAttentionBlock(nn.Module): ...@@ -808,19 +852,32 @@ class InducedSetAttentionBlock(nn.Module):
----- -----
This module was used in SetTransformer layer. This module was used in SetTransformer layer.
""" """
def __init__(self, m, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
def __init__(
self, m, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0
):
super(InducedSetAttentionBlock, self).__init__() super(InducedSetAttentionBlock, self).__init__()
self.m = m self.m = m
if m == 1: if m == 1:
dgl_warning("if m is set to 1, the parameters corresponding to query and key " dgl_warning(
"projections would not get updated during training.") "if m is set to 1, the parameters corresponding to query and key "
"projections would not get updated during training."
)
self.d_model = d_model self.d_model = d_model
self.inducing_points = nn.Parameter( self.inducing_points = nn.Parameter(th.FloatTensor(m, d_model))
th.FloatTensor(m, d_model) self.mha = nn.ModuleList(
[
MultiHeadAttention(
d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
for _ in range(2)
]
) )
self.mha = nn.ModuleList([
MultiHeadAttention(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta) for _ in range(2)])
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
...@@ -852,8 +909,10 @@ class InducedSetAttentionBlock(nn.Module): ...@@ -852,8 +909,10 @@ class InducedSetAttentionBlock(nn.Module):
"""Set the extra representation of the module. """Set the extra representation of the module.
which will come into effect when printing the model. which will come into effect when printing the model.
""" """
shape_str = '({}, {})'.format(self.inducing_points.shape[0], self.inducing_points.shape[1]) shape_str = "({}, {})".format(
return 'InducedVector: ' + shape_str self.inducing_points.shape[0], self.inducing_points.shape[1]
)
return "InducedVector: " + shape_str
class PMALayer(nn.Module): class PMALayer(nn.Module):
...@@ -881,23 +940,32 @@ class PMALayer(nn.Module): ...@@ -881,23 +940,32 @@ class PMALayer(nn.Module):
----- -----
This module was used in SetTransformer layer. This module was used in SetTransformer layer.
""" """
def __init__(self, k, d_model, num_heads, d_head, d_ff, dropouth=0., dropouta=0.):
def __init__(
self, k, d_model, num_heads, d_head, d_ff, dropouth=0.0, dropouta=0.0
):
super(PMALayer, self).__init__() super(PMALayer, self).__init__()
self.k = k self.k = k
if k == 1: if k == 1:
dgl_warning("if k is set to 1, the parameters corresponding to query and key " dgl_warning(
"projections would not get updated during training.") "if k is set to 1, the parameters corresponding to query and key "
"projections would not get updated during training."
)
self.d_model = d_model self.d_model = d_model
self.seed_vectors = nn.Parameter( self.seed_vectors = nn.Parameter(th.FloatTensor(k, d_model))
th.FloatTensor(k, d_model) self.mha = MultiHeadAttention(
d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
) )
self.mha = MultiHeadAttention(d_model, num_heads, d_head, d_ff,
dropouth=dropouth, dropouta=dropouta)
self.ffn = nn.Sequential( self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff), nn.Linear(d_model, d_ff),
nn.ReLU(), nn.ReLU(),
nn.Dropout(dropouth), nn.Dropout(dropouth),
nn.Linear(d_ff, d_model) nn.Linear(d_ff, d_model),
) )
self.reset_parameters() self.reset_parameters()
...@@ -929,8 +997,10 @@ class PMALayer(nn.Module): ...@@ -929,8 +997,10 @@ class PMALayer(nn.Module):
"""Set the extra representation of the module. """Set the extra representation of the module.
which will come into effect when printing the model. which will come into effect when printing the model.
""" """
shape_str = '({}, {})'.format(self.seed_vectors.shape[0], self.seed_vectors.shape[1]) shape_str = "({}, {})".format(
return 'SeedVector: ' + shape_str self.seed_vectors.shape[0], self.seed_vectors.shape[1]
)
return "SeedVector: " + shape_str
class SetTransformerEncoder(nn.Module): class SetTransformerEncoder(nn.Module):
...@@ -1018,27 +1088,57 @@ class SetTransformerEncoder(nn.Module): ...@@ -1018,27 +1088,57 @@ class SetTransformerEncoder(nn.Module):
representation instead of graphwise representation, and the SetTransformerDecoder representation instead of graphwise representation, and the SetTransformerDecoder
would return a graph readout tensor. would return a graph readout tensor.
""" """
def __init__(self, d_model, n_heads, d_head, d_ff,
n_layers=1, block_type='sab', m=None, dropouth=0., dropouta=0.): def __init__(
self,
d_model,
n_heads,
d_head,
d_ff,
n_layers=1,
block_type="sab",
m=None,
dropouth=0.0,
dropouta=0.0,
):
super(SetTransformerEncoder, self).__init__() super(SetTransformerEncoder, self).__init__()
self.n_layers = n_layers self.n_layers = n_layers
self.block_type = block_type self.block_type = block_type
self.m = m self.m = m
layers = [] layers = []
if block_type == 'isab' and m is None: if block_type == "isab" and m is None:
raise KeyError('The number of inducing points is not specified in ISAB block.') raise KeyError(
"The number of inducing points is not specified in ISAB block."
)
for _ in range(n_layers): for _ in range(n_layers):
if block_type == 'sab': if block_type == "sab":
layers.append( layers.append(
SetAttentionBlock(d_model, n_heads, d_head, d_ff, SetAttentionBlock(
dropouth=dropouth, dropouta=dropouta)) d_model,
elif block_type == 'isab': n_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
)
elif block_type == "isab":
layers.append( layers.append(
InducedSetAttentionBlock(m, d_model, n_heads, d_head, d_ff, InducedSetAttentionBlock(
dropouth=dropouth, dropouta=dropouta)) m,
d_model,
n_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
)
else: else:
raise KeyError("Unrecognized block type {}: we only support sab/isab") raise KeyError(
"Unrecognized block type {}: we only support sab/isab"
)
self.layers = nn.ModuleList(layers) self.layers = nn.ModuleList(layers)
...@@ -1136,18 +1236,43 @@ class SetTransformerDecoder(nn.Module): ...@@ -1136,18 +1236,43 @@ class SetTransformerDecoder(nn.Module):
-------- --------
SetTransformerEncoder SetTransformerEncoder
""" """
def __init__(self, d_model, num_heads, d_head, d_ff, n_layers, k, dropouth=0., dropouta=0.):
def __init__(
self,
d_model,
num_heads,
d_head,
d_ff,
n_layers,
k,
dropouth=0.0,
dropouta=0.0,
):
super(SetTransformerDecoder, self).__init__() super(SetTransformerDecoder, self).__init__()
self.n_layers = n_layers self.n_layers = n_layers
self.k = k self.k = k
self.d_model = d_model self.d_model = d_model
self.pma = PMALayer(k, d_model, num_heads, d_head, d_ff, self.pma = PMALayer(
dropouth=dropouth, dropouta=dropouta) k,
d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
layers = [] layers = []
for _ in range(n_layers): for _ in range(n_layers):
layers.append( layers.append(
SetAttentionBlock(d_model, num_heads, d_head, d_ff, SetAttentionBlock(
dropouth=dropouth, dropouta=dropouta)) d_model,
num_heads,
d_head,
d_ff,
dropouth=dropouth,
dropouta=dropouta,
)
)
self.layers = nn.ModuleList(layers) self.layers = nn.ModuleList(layers)
...@@ -1236,12 +1361,12 @@ class WeightAndSum(nn.Module): ...@@ -1236,12 +1361,12 @@ class WeightAndSum(nn.Module):
gcn_predictor.py>`__ gcn_predictor.py>`__
to understand how to use WeightAndSum layer to get the graph readout output. to understand how to use WeightAndSum layer to get the graph readout output.
""" """
def __init__(self, in_feats): def __init__(self, in_feats):
super(WeightAndSum, self).__init__() super(WeightAndSum, self).__init__()
self.in_feats = in_feats self.in_feats = in_feats
self.atom_weighting = nn.Sequential( self.atom_weighting = nn.Sequential(
nn.Linear(in_feats, 1), nn.Linear(in_feats, 1), nn.Sigmoid()
nn.Sigmoid()
) )
def forward(self, g, feats): def forward(self, g, feats):
...@@ -1261,8 +1386,8 @@ class WeightAndSum(nn.Module): ...@@ -1261,8 +1386,8 @@ class WeightAndSum(nn.Module):
Representations for B molecules Representations for B molecules
""" """
with g.local_scope(): with g.local_scope():
g.ndata['h'] = feats g.ndata["h"] = feats
g.ndata['w'] = self.atom_weighting(g.ndata['h']) g.ndata["w"] = self.atom_weighting(g.ndata["h"])
h_g_sum = sum_nodes(g, 'h', 'w') h_g_sum = sum_nodes(g, "h", "w")
return h_g_sum return h_g_sum
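A short sketch of the readout modules above on a batched graph (assuming imports from dgl.nn; each pooled row corresponds to one graph in the batch):

>>> import dgl
>>> import torch as th
>>> from torch import nn
>>> from dgl.nn import SumPooling, GlobalAttentionPooling  # assumed public import paths
>>> bg = dgl.batch([dgl.rand_graph(3, 4), dgl.rand_graph(4, 6)])
>>> feat = th.randn(bg.num_nodes(), 5)
>>> SumPooling()(bg, feat).shape
torch.Size([2, 5])
>>> gap = GlobalAttentionPooling(gate_nn=nn.Linear(5, 1))
>>> readout, attn = gap(bg, feat, get_attention=True)
>>> readout.shape, attn.shape
(torch.Size([2, 5]), torch.Size([7, 1]))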
"""Heterograph NN modules""" """Heterograph NN modules"""
from functools import partial from functools import partial
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
from ...base import DGLError from ...base import DGLError
__all__ = ['HeteroGraphConv', 'HeteroLinear', 'HeteroEmbedding'] __all__ = ["HeteroGraphConv", "HeteroLinear", "HeteroEmbedding"]
class HeteroGraphConv(nn.Module): class HeteroGraphConv(nn.Module):
r"""A generic module for computing convolution on heterogeneous graphs. r"""A generic module for computing convolution on heterogeneous graphs.
...@@ -120,7 +123,8 @@ class HeteroGraphConv(nn.Module): ...@@ -120,7 +123,8 @@ class HeteroGraphConv(nn.Module):
mods : dict[str, nn.Module] mods : dict[str, nn.Module]
Modules associated with every edge types. Modules associated with every edge types.
""" """
def __init__(self, mods, aggregate='sum'):
def __init__(self, mods, aggregate="sum"):
super(HeteroGraphConv, self).__init__() super(HeteroGraphConv, self).__init__()
self.mod_dict = mods self.mod_dict = mods
mods = {str(k): v for k, v in mods.items()} mods = {str(k): v for k, v in mods.items()}
...@@ -132,7 +136,9 @@ class HeteroGraphConv(nn.Module): ...@@ -132,7 +136,9 @@ class HeteroGraphConv(nn.Module):
# Do not break if graph has 0-in-degree nodes. # Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph. # Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items(): for _, v in self.mods.items():
set_allow_zero_in_degree_fn = getattr(v, 'set_allow_zero_in_degree', None) set_allow_zero_in_degree_fn = getattr(
v, "set_allow_zero_in_degree", None
)
if callable(set_allow_zero_in_degree_fn): if callable(set_allow_zero_in_degree_fn):
set_allow_zero_in_degree_fn(True) set_allow_zero_in_degree_fn(True)
if isinstance(aggregate, str): if isinstance(aggregate, str):
...@@ -148,7 +154,7 @@ class HeteroGraphConv(nn.Module): ...@@ -148,7 +154,7 @@ class HeteroGraphConv(nn.Module):
# etype is canonical # etype is canonical
_, etype, _ = etype _, etype, _ = etype
return self.mod_dict[etype] return self.mod_dict[etype]
raise KeyError('Cannot find module with edge type %s' % etype) raise KeyError("Cannot find module with edge type %s" % etype)
def forward(self, g, inputs, mod_args=None, mod_kwargs=None): def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
"""Forward computation """Forward computation
...@@ -175,13 +181,15 @@ class HeteroGraphConv(nn.Module): ...@@ -175,13 +181,15 @@ class HeteroGraphConv(nn.Module):
mod_args = {} mod_args = {}
if mod_kwargs is None: if mod_kwargs is None:
mod_kwargs = {} mod_kwargs = {}
outputs = {nty : [] for nty in g.dsttypes} outputs = {nty: [] for nty in g.dsttypes}
if isinstance(inputs, tuple) or g.is_block: if isinstance(inputs, tuple) or g.is_block:
if isinstance(inputs, tuple): if isinstance(inputs, tuple):
src_inputs, dst_inputs = inputs src_inputs, dst_inputs = inputs
else: else:
src_inputs = inputs src_inputs = inputs
dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()} dst_inputs = {
k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
}
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype] rel_graph = g[stype, etype, dtype]
...@@ -191,7 +199,8 @@ class HeteroGraphConv(nn.Module): ...@@ -191,7 +199,8 @@ class HeteroGraphConv(nn.Module):
rel_graph, rel_graph,
(src_inputs[stype], dst_inputs[dtype]), (src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()), *mod_args.get(etype, ()),
**mod_kwargs.get(etype, {})) **mod_kwargs.get(etype, {})
)
outputs[dtype].append(dstdata) outputs[dtype].append(dstdata)
else: else:
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
...@@ -202,7 +211,8 @@ class HeteroGraphConv(nn.Module): ...@@ -202,7 +211,8 @@ class HeteroGraphConv(nn.Module):
rel_graph, rel_graph,
(inputs[stype], inputs[dtype]), (inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()), *mod_args.get(etype, ()),
**mod_kwargs.get(etype, {})) **mod_kwargs.get(etype, {})
)
outputs[dtype].append(dstdata) outputs[dtype].append(dstdata)
rsts = {} rsts = {}
for nty, alist in outputs.items(): for nty, alist in outputs.items():
...@@ -210,29 +220,36 @@ class HeteroGraphConv(nn.Module): ...@@ -210,29 +220,36 @@ class HeteroGraphConv(nn.Module):
rsts[nty] = self.agg_fn(alist, nty) rsts[nty] = self.agg_fn(alist, nty)
return rsts return rsts
def _max_reduce_func(inputs, dim): def _max_reduce_func(inputs, dim):
return th.max(inputs, dim=dim)[0] return th.max(inputs, dim=dim)[0]
def _min_reduce_func(inputs, dim): def _min_reduce_func(inputs, dim):
return th.min(inputs, dim=dim)[0] return th.min(inputs, dim=dim)[0]
def _sum_reduce_func(inputs, dim): def _sum_reduce_func(inputs, dim):
return th.sum(inputs, dim=dim) return th.sum(inputs, dim=dim)
def _mean_reduce_func(inputs, dim): def _mean_reduce_func(inputs, dim):
return th.mean(inputs, dim=dim) return th.mean(inputs, dim=dim)
def _stack_agg_func(inputs, dsttype): # pylint: disable=unused-argument
def _stack_agg_func(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0: if len(inputs) == 0:
return None return None
return th.stack(inputs, dim=1) return th.stack(inputs, dim=1)
def _agg_func(inputs, dsttype, fn): # pylint: disable=unused-argument
def _agg_func(inputs, dsttype, fn): # pylint: disable=unused-argument
if len(inputs) == 0: if len(inputs) == 0:
return None return None
stacked = th.stack(inputs, dim=0) stacked = th.stack(inputs, dim=0)
return fn(stacked, dim=0) return fn(stacked, dim=0)
def get_aggregate_fn(agg): def get_aggregate_fn(agg):
"""Internal function to get the aggregation function for node data """Internal function to get the aggregation function for node data
generated from different relations. generated from different relations.
...@@ -249,24 +266,27 @@ def get_aggregate_fn(agg): ...@@ -249,24 +266,27 @@ def get_aggregate_fn(agg):
Aggregator function that takes a list of tensors to aggregate Aggregator function that takes a list of tensors to aggregate
and returns one aggregated tensor. and returns one aggregated tensor.
""" """
if agg == 'sum': if agg == "sum":
fn = _sum_reduce_func fn = _sum_reduce_func
elif agg == 'max': elif agg == "max":
fn = _max_reduce_func fn = _max_reduce_func
elif agg == 'min': elif agg == "min":
fn = _min_reduce_func fn = _min_reduce_func
elif agg == 'mean': elif agg == "mean":
fn = _mean_reduce_func fn = _mean_reduce_func
elif agg == 'stack': elif agg == "stack":
fn = None # will not be called fn = None # will not be called
else: else:
raise DGLError('Invalid cross type aggregator. Must be one of ' raise DGLError(
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg) "Invalid cross type aggregator. Must be one of "
if agg == 'stack': '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg
)
if agg == "stack":
return _stack_agg_func return _stack_agg_func
else: else:
return partial(_agg_func, fn=fn) return partial(_agg_func, fn=fn)
class HeteroLinear(nn.Module): class HeteroLinear(nn.Module):
"""Apply linear transformations on heterogeneous inputs. """Apply linear transformations on heterogeneous inputs.
...@@ -294,6 +314,7 @@ class HeteroLinear(nn.Module): ...@@ -294,6 +314,7 @@ class HeteroLinear(nn.Module):
>>> print(out_feats[('user', 'follows', 'user')].shape) >>> print(out_feats[('user', 'follows', 'user')].shape)
torch.Size([3, 3]) torch.Size([3, 3])
""" """
def __init__(self, in_size, out_size, bias=True): def __init__(self, in_size, out_size, bias=True):
super(HeteroLinear, self).__init__() super(HeteroLinear, self).__init__()
...@@ -320,6 +341,7 @@ class HeteroLinear(nn.Module): ...@@ -320,6 +341,7 @@ class HeteroLinear(nn.Module):
return out_feat return out_feat
class HeteroEmbedding(nn.Module): class HeteroEmbedding(nn.Module):
"""Create a heterogeneous embedding table. """Create a heterogeneous embedding table.
...@@ -356,6 +378,7 @@ class HeteroEmbedding(nn.Module): ...@@ -356,6 +378,7 @@ class HeteroEmbedding(nn.Module):
>>> print(embeds[('user', 'follows', 'user')].shape) >>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([2, 4]) torch.Size([2, 4])
""" """
def __init__(self, num_embeddings, embedding_dim): def __init__(self, num_embeddings, embedding_dim):
super(HeteroEmbedding, self).__init__() super(HeteroEmbedding, self).__init__()
...@@ -374,7 +397,9 @@ class HeteroEmbedding(nn.Module): ...@@ -374,7 +397,9 @@ class HeteroEmbedding(nn.Module):
dict[key, Tensor] dict[key, Tensor]
Heterogeneous embedding table Heterogeneous embedding table
""" """
return {self.raw_keys[typ]: emb.weight for typ, emb in self.embeds.items()} return {
self.raw_keys[typ]: emb.weight for typ, emb in self.embeds.items()
}
def reset_parameters(self): def reset_parameters(self):
""" """
......
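A minimal sketch for HeteroGraphConv (assuming dgl.nn is importable as dglnn; node counts are inferred from the edge lists). The wrapper runs one sub-module per relation and aggregates per destination type; as the hunk above shows, it also turns on allow_zero_in_degree for sub-modules that support it:

>>> import dgl
>>> import dgl.nn as dglnn
>>> import torch as th
>>> # illustrative two-relation heterograph
>>> g = dgl.heterograph({
...     ("user", "follows", "user"): (th.tensor([0, 1]), th.tensor([1, 2])),
...     ("user", "plays", "game"): (th.tensor([0, 2]), th.tensor([0, 1])),
... })
>>> conv = dglnn.HeteroGraphConv(
...     {"follows": dglnn.GraphConv(4, 8), "plays": dglnn.GraphConv(4, 8)},
...     aggregate="sum",
... )
>>> out = conv(g, {"user": th.randn(3, 4), "game": th.randn(2, 4)})
>>> out["user"].shape, out["game"].shape
(torch.Size([3, 8]), torch.Size([2, 8]))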
"""Various commonly used linear modules""" """Various commonly used linear modules"""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235 # pylint: disable= no-member, arguments-differ, invalid-name, W0235
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...ops import segment_mm, gather_mm from ...ops import gather_mm, segment_mm
__all__ = ["TypedLinear"]
__all__ = ['TypedLinear']
class TypedLinear(nn.Module): class TypedLinear(nn.Module):
r"""Linear transformation according to types. r"""Linear transformation according to types.
...@@ -81,35 +83,43 @@ class TypedLinear(nn.Module): ...@@ -81,35 +83,43 @@ class TypedLinear(nn.Module):
>>> print(y.shape) >>> print(y.shape)
torch.Size([100, 64]) torch.Size([100, 64])
""" """
def __init__(self, in_size, out_size, num_types,
regularizer=None, num_bases=None): def __init__(
self, in_size, out_size, num_types, regularizer=None, num_bases=None
):
super().__init__() super().__init__()
self.in_size = in_size self.in_size = in_size
self.out_size = out_size self.out_size = out_size
self.num_types = num_types self.num_types = num_types
if regularizer is None: if regularizer is None:
self.W = nn.Parameter(torch.Tensor(num_types, in_size, out_size)) self.W = nn.Parameter(torch.Tensor(num_types, in_size, out_size))
elif regularizer == 'basis': elif regularizer == "basis":
if num_bases is None: if num_bases is None:
raise ValueError('Missing "num_bases" for basis regularization.') raise ValueError(
'Missing "num_bases" for basis regularization.'
)
self.W = nn.Parameter(torch.Tensor(num_bases, in_size, out_size)) self.W = nn.Parameter(torch.Tensor(num_bases, in_size, out_size))
self.coeff = nn.Parameter(torch.Tensor(num_types, num_bases)) self.coeff = nn.Parameter(torch.Tensor(num_types, num_bases))
self.num_bases = num_bases self.num_bases = num_bases
elif regularizer == 'bdd': elif regularizer == "bdd":
if num_bases is None: if num_bases is None:
raise ValueError('Missing "num_bases" for bdd regularization.') raise ValueError('Missing "num_bases" for bdd regularization.')
if in_size % num_bases != 0 or out_size % num_bases != 0: if in_size % num_bases != 0 or out_size % num_bases != 0:
raise ValueError( raise ValueError(
'Input and output sizes must be divisible by num_bases.' "Input and output sizes must be divisible by num_bases."
) )
self.submat_in = in_size // num_bases self.submat_in = in_size // num_bases
self.submat_out = out_size // num_bases self.submat_out = out_size // num_bases
self.W = nn.Parameter(torch.Tensor( self.W = nn.Parameter(
num_types, num_bases * self.submat_in * self.submat_out)) torch.Tensor(
num_types, num_bases * self.submat_in * self.submat_out
)
)
self.num_bases = num_bases self.num_bases = num_bases
else: else:
raise ValueError( raise ValueError(
f'Supported regularizer options: "basis", "bdd", but got {regularizer}') f'Supported regularizer options: "basis", "bdd", but got {regularizer}'
)
self.regularizer = regularizer self.regularizer = regularizer
self.reset_parameters() self.reset_parameters()
...@@ -118,28 +128,46 @@ class TypedLinear(nn.Module): ...@@ -118,28 +128,46 @@ class TypedLinear(nn.Module):
with torch.no_grad(): with torch.no_grad():
# Follow torch.nn.Linear 's initialization to use kaiming_uniform_ on in_size # Follow torch.nn.Linear 's initialization to use kaiming_uniform_ on in_size
if self.regularizer is None: if self.regularizer is None:
nn.init.uniform_(self.W, -1/math.sqrt(self.in_size), 1/math.sqrt(self.in_size)) nn.init.uniform_(
elif self.regularizer == 'basis': self.W,
nn.init.uniform_(self.W, -1/math.sqrt(self.in_size), 1/math.sqrt(self.in_size)) -1 / math.sqrt(self.in_size),
nn.init.xavier_uniform_(self.coeff, gain=nn.init.calculate_gain('relu')) 1 / math.sqrt(self.in_size),
elif self.regularizer == 'bdd': )
nn.init.uniform_(self.W, -1/math.sqrt(self.submat_in), 1/math.sqrt(self.submat_in)) elif self.regularizer == "basis":
nn.init.uniform_(
self.W,
-1 / math.sqrt(self.in_size),
1 / math.sqrt(self.in_size),
)
nn.init.xavier_uniform_(
self.coeff, gain=nn.init.calculate_gain("relu")
)
elif self.regularizer == "bdd":
nn.init.uniform_(
self.W,
-1 / math.sqrt(self.submat_in),
1 / math.sqrt(self.submat_in),
)
else: else:
raise ValueError( raise ValueError(
f'Supported regularizer options: "basis", "bdd", but got {regularizer}') f'Supported regularizer options: "basis", "bdd", but got {regularizer}'
)
def get_weight(self): def get_weight(self):
"""Get type-wise weight""" """Get type-wise weight"""
if self.regularizer is None: if self.regularizer is None:
return self.W return self.W
elif self.regularizer == 'basis': elif self.regularizer == "basis":
W = self.W.view(self.num_bases, self.in_size * self.out_size) W = self.W.view(self.num_bases, self.in_size * self.out_size)
return (self.coeff @ W).view(self.num_types, self.in_size, self.out_size) return (self.coeff @ W).view(
elif self.regularizer == 'bdd': self.num_types, self.in_size, self.out_size
)
elif self.regularizer == "bdd":
return self.W return self.W
else: else:
raise ValueError( raise ValueError(
f'Supported regularizer options: "basis", "bdd", but got {regularizer}') f'Supported regularizer options: "basis", "bdd", but got {regularizer}'
)
def forward(self, x, x_type, sorted_by_type=False): def forward(self, x, x_type, sorted_by_type=False):
"""Forward computation. """Forward computation.
...@@ -161,23 +189,35 @@ class TypedLinear(nn.Module): ...@@ -161,23 +189,35 @@ class TypedLinear(nn.Module):
The transformed output tensor. Shape: (N, D2) The transformed output tensor. Shape: (N, D2)
""" """
w = self.get_weight() w = self.get_weight()
if self.regularizer == 'bdd': if self.regularizer == "bdd":
w = w.index_select(0, x_type).view(-1, self.submat_in, self.submat_out) w = w.index_select(0, x_type).view(
-1, self.submat_in, self.submat_out
)
x = x.view(-1, 1, self.submat_in) x = x.view(-1, 1, self.submat_in)
return torch.bmm(x, w).view(-1, self.out_size) return torch.bmm(x, w).view(-1, self.out_size)
elif sorted_by_type: elif sorted_by_type:
pos_l = torch.searchsorted(x_type, torch.arange(self.num_types, device=x.device)) pos_l = torch.searchsorted(
pos_r = torch.cat([pos_l[1:], torch.tensor([len(x_type)], device=x.device)]) x_type, torch.arange(self.num_types, device=x.device)
seglen = (pos_r - pos_l).cpu() # XXX(minjie): cause device synchronize )
pos_r = torch.cat(
[pos_l[1:], torch.tensor([len(x_type)], device=x.device)]
)
seglen = (
pos_r - pos_l
).cpu() # XXX(minjie): cause device synchronize
return segment_mm(x, w, seglen_a=seglen) return segment_mm(x, w, seglen_a=seglen)
else: else:
return gather_mm(x, w, idx_b=x_type) return gather_mm(x, w, idx_b=x_type)
def __repr__(self): def __repr__(self):
if self.regularizer is None: if self.regularizer is None:
return (f'TypedLinear(in_size={self.in_size}, out_size={self.out_size}, ' return (
f'num_types={self.num_types})') f"TypedLinear(in_size={self.in_size}, out_size={self.out_size}, "
f"num_types={self.num_types})"
)
else: else:
return (f'TypedLinear(in_size={self.in_size}, out_size={self.out_size}, ' return (
f'num_types={self.num_types}, regularizer={self.regularizer}, ' f"TypedLinear(in_size={self.in_size}, out_size={self.out_size}, "
f'num_bases={self.num_bases})') f"num_types={self.num_types}, regularizer={self.regularizer}, "
f"num_bases={self.num_bases})"
)
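A usage sketch for TypedLinear with the basis regularizer, matching the shape quoted in the docstring above (assuming the module is importable as dgl.nn.TypedLinear; sizes are illustrative):

>>> import torch
>>> from dgl.nn import TypedLinear  # assumed public import path
>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> layer = TypedLinear(32, 64, 5, regularizer="basis", num_bases=3)
>>> layer(x, x_type).shape
torch.Size([100, 64])

Setting sorted_by_type=True is only safe when x is already grouped by type in ascending order; forward then uses segment_mm rather than gather_mm, as in the hunk above.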