Unverified Commit 76bb5404 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
......@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
class EdgePredictor(nn.Module):
r"""Predictor/score function for pairs of node representations
......@@ -102,20 +103,21 @@ class EdgePredictor(nn.Module):
>>> predictor(h_src, h_dst).shape
torch.Size([3, 3])
"""
def __init__(self,
op,
in_feats=None,
out_feats=None,
bias=False):
def __init__(self, op, in_feats=None, out_feats=None, bias=False):
super(EdgePredictor, self).__init__()
assert op in ['dot', 'cos', 'ele', 'cat'], \
"Expect op to be in ['dot', 'cos', 'ele', 'cat'], got {}".format(op)
assert op in [
"dot",
"cos",
"ele",
"cat",
], "Expect op to be in ['dot', 'cos', 'ele', 'cat'], got {}".format(op)
self.op = op
if (in_feats is not None) and (out_feats is not None):
if op in ['dot', 'cos']:
if op in ["dot", "cos"]:
in_feats = 1
elif op == 'cat':
elif op == "cat":
in_feats = 2 * in_feats
self.linear = nn.Linear(in_feats, out_feats, bias=bias)
else:
......@@ -154,12 +156,12 @@ class EdgePredictor(nn.Module):
torch.Tensor
The output features.
"""
if self.op == 'dot':
if self.op == "dot":
N, D = h_src.shape
h = torch.bmm(h_src.view(N, 1, D), h_dst.view(N, D, 1)).squeeze(-1)
elif self.op == 'cos':
elif self.op == "cos":
h = F.cosine_similarity(h_src, h_dst).unsqueeze(-1)
elif self.op == 'ele':
elif self.op == "ele":
h = h_src * h_dst
else:
h = torch.cat([h_src, h_dst], dim=-1)
......
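For context, a minimal usage sketch of the EdgePredictor module touched above; the import path and tensor sizes are assumptions for illustration, not part of this diff:
>>> import torch
>>> from dgl.nn import EdgePredictor  # assumed import path
>>> h_src = torch.randn(5, 8)
>>> h_dst = torch.randn(5, 8)
>>> # with op='cos' and no in_feats/out_feats, no linear projection is applied
>>> predictor = EdgePredictor('cos')
>>> predictor(h_src, h_dst).shape
torch.Size([5, 1])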
......@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
class TransE(nn.Module):
r"""Similarity measure from `Translating Embeddings for Modeling Multi-relational Data
<https://papers.nips.cc/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html>`__
......@@ -53,6 +54,7 @@ class TransE(nn.Module):
>>> scorer(h_head, h_tail, rels).shape
torch.Size([30])
"""
def __init__(self, num_rels, feats, p=1):
super(TransE, self).__init__()
......@@ -94,4 +96,4 @@ class TransE(nn.Module):
"""
h_rel = self.rel_emb(rels)
return - torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)
return -torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)
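A minimal sketch of scoring triples with the TransE module above (import path and sizes are illustrative assumptions):
>>> import torch
>>> from dgl.nn import TransE  # assumed import path
>>> scorer = TransE(num_rels=3, feats=4)
>>> h_head = torch.randn(10, 4)
>>> h_tail = torch.randn(10, 4)
>>> rels = torch.randint(0, 3, (10,))
>>> # score = -||h_head + h_rel - h_tail||_p, one value per triple
>>> scorer(h_head, h_tail, rels).shape
torch.Size([10])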
......@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
class TransR(nn.Module):
r"""Similarity measure from
`Learning entity and relation embeddings for knowledge graph completion
......@@ -58,6 +59,7 @@ class TransR(nn.Module):
>>> scorer(h_head, h_tail, rels).shape
torch.Size([30])
"""
def __init__(self, num_rels, rfeats, nfeats, p=1):
super(TransR, self).__init__()
......@@ -103,4 +105,4 @@ class TransR(nn.Module):
h_head = (h_head.unsqueeze(1) @ proj_rel).squeeze(1)
h_tail = (h_tail.unsqueeze(1) @ proj_rel).squeeze(1)
return - torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)
return -torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)
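Similarly for TransR, which projects entity features into the relation space before scoring (import path and sizes are illustrative assumptions):
>>> import torch
>>> from dgl.nn import TransR  # assumed import path
>>> scorer = TransR(num_rels=3, rfeats=2, nfeats=4)
>>> h_head = torch.randn(10, 4)
>>> h_tail = torch.randn(10, 4)
>>> rels = torch.randint(0, 3, (10,))
>>> scorer(h_head, h_tail, rels).shape
torch.Size([10])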
"""Torch NodeEmbedding."""
from datetime import timedelta
import torch as th
from ...backend import pytorch as F
from ...utils import get_shared_mem_array, create_shared_mem_array
from ...cuda import nccl
from ...partition import NDArrayPartition
from ...utils import create_shared_mem_array, get_shared_mem_array
_STORE = None
_COMM = None
class NodeEmbedding: # NodeEmbedding
'''Class for storing node embeddings.
class NodeEmbedding: # NodeEmbedding
"""Class for storing node embeddings.
The class is optimized for training large-scale node embeddings. It updates the embedding in
a sparse way and can scale to graphs with millions of nodes. It also supports partitioning
......@@ -63,15 +66,22 @@ class NodeEmbedding: # NodeEmbedding
... loss = F.sum(feats + 1, 0)
... loss.backward()
... optimizer.step()
'''
def __init__(self, num_embeddings, embedding_dim, name,
init_func=None, device=None, partition=None):
"""
def __init__(
self,
num_embeddings,
embedding_dim,
name,
init_func=None,
device=None,
partition=None,
):
global _STORE
global _COMM
if device is None:
device = th.device('cpu')
device = th.device("cpu")
# Check whether it is multi-gpu training or not.
if th.distributed.is_initialized():
......@@ -86,7 +96,7 @@ class NodeEmbedding: # NodeEmbedding
self._comm = None
self._partition = partition
host_name = '127.0.0.1'
host_name = "127.0.0.1"
port = 12346
if rank >= 0:
......@@ -94,25 +104,34 @@ class NodeEmbedding: # NodeEmbedding
# embedding status synchronization across GPU processes
if _STORE is None:
_STORE = th.distributed.TCPStore(
host_name, port, world_size, rank == 0, timedelta(seconds=10*60))
host_name,
port,
world_size,
rank == 0,
timedelta(seconds=10 * 60),
)
self._store = _STORE
# embeddings is stored in CPU memory.
if th.device(device) == th.device('cpu'):
if th.device(device) == th.device("cpu"):
if rank <= 0:
emb = create_shared_mem_array(name, (num_embeddings, embedding_dim), th.float32)
emb = create_shared_mem_array(
name, (num_embeddings, embedding_dim), th.float32
)
if init_func is not None:
emb = init_func(emb)
if rank == 0: # the master gpu process
if rank == 0: # the master gpu process
for _ in range(1, world_size):
# send embs
self._store.set(name, name)
elif rank > 0:
# receive
self._store.wait([name])
emb = get_shared_mem_array(name, (num_embeddings, embedding_dim), th.float32)
emb = get_shared_mem_array(
name, (num_embeddings, embedding_dim), th.float32
)
self._tensor = emb
else: # embeddings is stored in GPU memory.
else: # embeddings is stored in GPU memory.
# setup nccl communicator
if _COMM is None:
if rank < 0:
......@@ -123,11 +142,14 @@ class NodeEmbedding: # NodeEmbedding
if rank == 0:
# root process broadcasts nccl id
nccl_id = nccl.UniqueId()
self._store.set('nccl_root_id_sparse_emb', str(nccl_id))
self._store.set("nccl_root_id_sparse_emb", str(nccl_id))
else:
nccl_id = nccl.UniqueId(self._store.get('nccl_root_id_sparse_emb'))
_COMM = nccl.Communicator(self._world_size, self._rank,
nccl_id)
nccl_id = nccl.UniqueId(
self._store.get("nccl_root_id_sparse_emb")
)
_COMM = nccl.Communicator(
self._world_size, self._rank, nccl_id
)
self._comm = _COMM
if not self._partition:
......@@ -135,14 +157,19 @@ class NodeEmbedding: # NodeEmbedding
self._partition = NDArrayPartition(
num_embeddings,
self._world_size if self._world_size > 0 else 1,
mode='remainder')
mode="remainder",
)
# create local tensors for the weights
local_size = self._partition.local_size(self._comm.rank())
# TODO(dlasalle): support 16-bit/half embeddings
emb = th.empty([local_size, embedding_dim], dtype=th.float32,
requires_grad=False, device=device)
emb = th.empty(
[local_size, embedding_dim],
dtype=th.float32,
requires_grad=False,
device=device,
)
if init_func:
emb = init_func(emb)
self._tensor = emb
......@@ -150,10 +177,10 @@ class NodeEmbedding: # NodeEmbedding
self._num_embeddings = num_embeddings
self._embedding_dim = embedding_dim
self._name = name
self._optm_state = None # track optimizer state
self._trace = [] # track minibatch
self._optm_state = None # track optimizer state
self._trace = [] # track minibatch
def __call__(self, node_ids, device=th.device('cpu')):
def __call__(self, node_ids, device=th.device("cpu")):
"""
node_ids : th.tensor
Index of the embeddings to collect.
......@@ -165,7 +192,8 @@ class NodeEmbedding: # NodeEmbedding
else:
if self.world_size > 0:
emb = self._comm.sparse_all_to_all_pull(
node_ids, self._tensor, self._partition)
node_ids, self._tensor, self._partition
)
else:
emb = self._tensor[node_ids]
emb = emb.to(device)
......@@ -331,7 +359,7 @@ class NodeEmbedding: # NodeEmbedding
return self._tensor
def all_set_embedding(self, values):
""" Set the values of the embedding. This method must be called by all
"""Set the values of the embedding. This method must be called by all
processes sharing the embedding with identical tensors for
:attr:`values`.
......@@ -346,20 +374,23 @@ class NodeEmbedding: # NodeEmbedding
if self._partition:
idxs = F.copy_to(
self._partition.get_local_indices(
self._comm.rank(),
ctx=F.context(self._tensor)),
F.context(values))
self._tensor[:] = F.copy_to(F.gather_row(values, idxs),
ctx=F.context(self._tensor))[:]
self._comm.rank(), ctx=F.context(self._tensor)
),
F.context(values),
)
self._tensor[:] = F.copy_to(
F.gather_row(values, idxs), ctx=F.context(self._tensor)
)[:]
else:
if self._rank == 0:
self._tensor[:] = F.copy_to(values,
ctx=F.context(self._tensor))[:]
self._tensor[:] = F.copy_to(
values, ctx=F.context(self._tensor)
)[:]
if th.distributed.is_initialized():
th.distributed.barrier()
def all_get_embedding(self):
""" Return a copy of the embedding stored in CPU memory. If this is a
"""Return a copy of the embedding stored in CPU memory. If this is a
multi-processing instance, the tensor will be returned in shared
memory. If the embedding is currently stored on multiple GPUs, all
processes must call this method in the same order.
......@@ -375,7 +406,7 @@ class NodeEmbedding: # NodeEmbedding
if self._partition:
if self._world_size == 0:
# non-multiprocessing
return self._tensor.to(th.device('cpu'))
return self._tensor.to(th.device("cpu"))
else:
# create a shared memory tensor
shared_name = self._name + "_gather"
......@@ -384,18 +415,23 @@ class NodeEmbedding: # NodeEmbedding
emb = create_shared_mem_array(
shared_name,
(self._num_embeddings, self._embedding_dim),
self._tensor.dtype)
self._tensor.dtype,
)
self._store.set(shared_name, shared_name)
else:
self._store.wait([shared_name])
emb = get_shared_mem_array(
shared_name, (self._num_embeddings, self._embedding_dim),
self._tensor.dtype)
shared_name,
(self._num_embeddings, self._embedding_dim),
self._tensor.dtype,
)
# need to map indices and slice into existing tensor
idxs = self._partition.map_to_global(
F.arange(0, self._tensor.shape[0],
ctx=F.context(self._tensor)),
self._rank).to(emb.device)
F.arange(
0, self._tensor.shape[0], ctx=F.context(self._tensor)
),
self._rank,
).to(emb.device)
emb[idxs] = self._tensor.to(emb.device)
# wait for all processes to finish
......
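A minimal training sketch for NodeEmbedding together with a DGL sparse optimizer, assuming the usual dgl.nn / dgl.optim entry points; names and sizes are illustrative:
>>> import torch as th
>>> from dgl.nn import NodeEmbedding   # assumed import path
>>> from dgl.optim import SparseAdam   # assumed import path
>>> def initializer(emb):
...     th.nn.init.xavier_uniform_(emb)
...     return emb
>>> emb = NodeEmbedding(1000, 16, name='example_emb', init_func=initializer)
>>> optimizer = SparseAdam([emb], lr=0.01)
>>> node_ids = th.arange(0, 32)
>>> feats = emb(node_ids)  # gather a minibatch of rows; accesses are traced for the sparse update
>>> loss = feats.sum()
>>> loss.backward()
>>> optimizer.step()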
"""Utilities for pytorch NN package"""
#pylint: disable=no-member, invalid-name
# pylint: disable=no-member, invalid-name
import torch as th
from torch import nn
import torch.nn.functional as F
from torch import nn
from ... import DGLGraph
from ...base import dgl_warning
from ... import function as fn
from ...base import dgl_warning
def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector.
......@@ -49,6 +51,7 @@ def matmul_maybe_select(A, B):
else:
return th.matmul(A, B)
def bmm_maybe_select(A, B, index):
"""Slice submatrices of A by the given index and perform bmm.
......@@ -92,12 +95,14 @@ def bmm_maybe_select(A, B, index):
BB = B.index_select(0, index)
return th.bmm(A.unsqueeze(1), BB).squeeze()
# pylint: disable=W0235
class Identity(nn.Module):
"""A placeholder identity operator that is argument-insensitive.
(Identity has already been supported by PyTorch 1.2, we will directly
import torch.nn.Identity in the future)
"""
def __init__(self):
super(Identity, self).__init__()
......@@ -105,6 +110,7 @@ class Identity(nn.Module):
"""Return input"""
return x
class Sequential(nn.Sequential):
r"""A sequential container for stacking graph neural network modules
......@@ -220,10 +226,13 @@ class Sequential(nn.Sequential):
feats = (feats,)
feats = module(graph, *feats)
else:
raise TypeError('The first argument of forward must be a DGLGraph'
' or a list of DGLGraph s')
raise TypeError(
"The first argument of forward must be a DGLGraph"
" or a list of DGLGraph s"
)
return feats
class WeightBasis(nn.Module):
r"""Basis decomposition from `Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__
......@@ -249,24 +258,28 @@ class WeightBasis(nn.Module):
num_outputs : int
Number of outputs.
"""
def __init__(self,
shape,
num_bases,
num_outputs):
def __init__(self, shape, num_bases, num_outputs):
super(WeightBasis, self).__init__()
self.shape = shape
self.num_bases = num_bases
self.num_outputs = num_outputs
if num_outputs <= num_bases:
dgl_warning('The number of weight outputs should be larger than the number'
' of bases.')
dgl_warning(
"The number of weight outputs should be larger than the number"
" of bases."
)
self.weight = nn.Parameter(th.Tensor(self.num_bases, *shape))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.weight, gain=nn.init.calculate_gain("relu")
)
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_outputs, self.num_bases))
nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.w_comp, gain=nn.init.calculate_gain("relu")
)
def forward(self):
r"""Forward computation
......@@ -280,6 +293,7 @@ class WeightBasis(nn.Module):
weight = th.matmul(self.w_comp, self.weight.view(self.num_bases, -1))
return weight.view(self.num_outputs, *self.shape)
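The basis decomposition above composes the per-output weights from num_bases shared bases; a quick shape check, with the module path taken from the file being reformatted (dgl.nn.pytorch.utils):
>>> from dgl.nn.pytorch.utils import WeightBasis
>>> basis = WeightBasis(shape=(16, 8), num_bases=4, num_outputs=10)
>>> # W[o] = sum_b w_comp[o, b] * weight[b], giving a (num_outputs, *shape) tensor
>>> basis().shape
torch.Size([10, 16, 8])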
class JumpingKnowledge(nn.Module):
r"""The Jumping Knowledge aggregation module from `Representation Learning on
Graphs with Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__
......@@ -345,17 +359,25 @@ class JumpingKnowledge(nn.Module):
>>> model(feat_list).shape
torch.Size([3, 4])
"""
def __init__(self, mode='cat', in_feats=None, num_layers=None):
def __init__(self, mode="cat", in_feats=None, num_layers=None):
super(JumpingKnowledge, self).__init__()
assert mode in ['cat', 'max', 'lstm'], \
"Expect mode to be 'cat', or 'max' or 'lstm', got {}".format(mode)
assert mode in [
"cat",
"max",
"lstm",
], "Expect mode to be 'cat', or 'max' or 'lstm', got {}".format(mode)
self.mode = mode
if mode == 'lstm':
assert in_feats is not None, 'in_feats is required for lstm mode'
assert num_layers is not None, 'num_layers is required for lstm mode'
if mode == "lstm":
assert in_feats is not None, "in_feats is required for lstm mode"
assert (
num_layers is not None
), "num_layers is required for lstm mode"
hidden_size = (num_layers * in_feats) // 2
self.lstm = nn.LSTM(in_feats, hidden_size, bidirectional=True, batch_first=True)
self.lstm = nn.LSTM(
in_feats, hidden_size, bidirectional=True, batch_first=True
)
self.att = nn.Linear(2 * hidden_size, 1)
def reset_parameters(self):
......@@ -365,7 +387,7 @@ class JumpingKnowledge(nn.Module):
-----------
Reinitialize learnable parameters. This comes into effect only for the lstm mode.
"""
if self.mode == 'lstm':
if self.mode == "lstm":
self.lstm.reset_parameters()
self.att.reset_parameters()
......@@ -386,18 +408,21 @@ class JumpingKnowledge(nn.Module):
Tensor
The aggregated representations.
"""
if self.mode == 'cat':
if self.mode == "cat":
return th.cat(feat_list, dim=-1)
elif self.mode == 'max':
elif self.mode == "max":
return th.stack(feat_list, dim=-1).max(dim=-1)[0]
else:
# LSTM
stacked_feat_list = th.stack(feat_list, dim=1) # (N, num_layers, in_feats)
stacked_feat_list = th.stack(
feat_list, dim=1
) # (N, num_layers, in_feats)
alpha, _ = self.lstm(stacked_feat_list)
alpha = self.att(alpha).squeeze(-1) # (N, num_layers)
alpha = self.att(alpha).squeeze(-1) # (N, num_layers)
alpha = th.softmax(alpha, dim=-1)
return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1)
class LabelPropagation(nn.Module):
r"""Label Propagation from `Learning from Labeled and Unlabeled Data with Label
Propagation <http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf>`__
......@@ -447,7 +472,16 @@ class LabelPropagation(nn.Module):
>>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()
>>> new_labels = label_propagation(g, labels, mask)
"""
def __init__(self, k, alpha, norm_type='sym', clamp=True, normalize=False, reset=False):
def __init__(
self,
k,
alpha,
norm_type="sym",
clamp=True,
normalize=False,
reset=False,
):
super(LabelPropagation, self).__init__()
self.k = k
self.alpha = alpha
......@@ -498,21 +532,23 @@ class LabelPropagation(nn.Module):
init = (1 - self.alpha) * y
in_degs = g.in_degrees().float().clamp(min=1)
out_degs = g.out_degrees().float().clamp(min=1)
if self.norm_type == 'sym':
if self.norm_type == "sym":
norm_i = th.pow(in_degs, -0.5).to(labels.device).unsqueeze(1)
norm_j = th.pow(out_degs, -0.5).to(labels.device).unsqueeze(1)
elif self.norm_type == 'row':
norm_i = th.pow(in_degs, -1.).to(labels.device).unsqueeze(1)
elif self.norm_type == "row":
norm_i = th.pow(in_degs, -1.0).to(labels.device).unsqueeze(1)
else:
raise ValueError(f"Expect norm_type to be 'sym' or 'row', got {self.norm_type}")
raise ValueError(
f"Expect norm_type to be 'sym' or 'row', got {self.norm_type}"
)
for _ in range(self.k):
g.ndata['h'] = y * norm_j if self.norm_type == 'sym' else y
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
y = init + self.alpha * g.ndata['h'] * norm_i
g.ndata["h"] = y * norm_j if self.norm_type == "sym" else y
g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
y = init + self.alpha * g.ndata["h"] * norm_i
if self.clamp:
y = y.clamp_(0., 1.)
y = y.clamp_(0.0, 1.0)
if self.normalize:
y = F.normalize(y, p=1)
if self.reset:
......
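A minimal sketch of running the LabelPropagation module above on a small graph; the loop implements Y <- (1 - alpha) * Y0 + alpha * A_norm @ Y for k steps, with A_norm the degree-normalized adjacency. The import path and the graph are illustrative assumptions:
>>> import dgl
>>> import torch
>>> from dgl.nn import LabelPropagation  # assumed import path
>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> labels = torch.tensor([0, 2, 1, 1, 0]).long()
>>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()
>>> lp = LabelPropagation(k=3, alpha=0.9, clamp=True)
>>> new_labels = lp(g, labels, mask)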
"""Package for Tensorflow-specific NN modules."""
from .conv import *
from .softmax import *
from .utils import *
from .glob import *
from .hetero import *
from .softmax import *
from .utils import *
"""TF NN conv module"""
from .gatconv import GATConv
from .relgraphconv import RelGraphConv
from .graphconv import GraphConv
from .ginconv import GINConv
from .sageconv import SAGEConv
from .sgconv import SGConv
from .appnpconv import APPNPConv
from .chebconv import ChebConv
from .densechebconv import DenseChebConv
from .edgeconv import EdgeConv
from .gatconv import GATConv
from .ginconv import GINConv
from .graphconv import GraphConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
from .sgconv import SGConv
"""TF Module for APPNPConv"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from .... import function as fn
......@@ -29,10 +29,7 @@ class APPNPConv(layers.Layer):
messages received by each node. Default: ``0``.
"""
def __init__(self,
k,
alpha,
edge_drop=0.):
def __init__(self, k, alpha, edge_drop=0.0):
super(APPNPConv, self).__init__()
self._k = k
self._alpha = alpha
......@@ -56,8 +53,11 @@ class APPNPConv(layers.Layer):
should be the same as input shape.
"""
with graph.local_scope():
degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32),
clip_value_min=1, clip_value_max=np.inf)
degs = tf.clip_by_value(
tf.cast(graph.in_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf,
)
norm = tf.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1)
norm = tf.reshape(norm, shp)
......@@ -65,12 +65,12 @@ class APPNPConv(layers.Layer):
for _ in range(self._k):
# normalization by src node
feat = feat * norm
graph.ndata['h'] = feat
graph.edata['w'] = self.edge_drop(
tf.ones(graph.number_of_edges(), 1))
graph.update_all(fn.u_mul_e('h', 'w', 'm'),
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
graph.ndata["h"] = feat
graph.edata["w"] = self.edge_drop(
tf.ones(graph.number_of_edges(), 1)
)
graph.update_all(fn.u_mul_e("h", "w", "m"), fn.sum("m", "h"))
feat = graph.ndata.pop("h")
# normalization by dst node
feat = feat * norm
feat = (1 - self._alpha) * feat + self._alpha * feat_0
......
"""Tensorflow Module for Chebyshev Spectral Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from .... import broadcast_nodes
from .... import function as fn
from ....base import dgl_warning
from .... import broadcast_nodes, function as fn
class ChebConv(layers.Layer):
......@@ -60,12 +61,9 @@ class ChebConv(layers.Layer):
[-0.2370, 3.0164]], dtype=float32)>
"""
def __init__(self,
in_feats,
out_feats,
k,
activation=tf.nn.relu,
bias=True):
def __init__(
self, in_feats, out_feats, k, activation=tf.nn.relu, bias=True
):
super(ChebConv, self).__init__()
self._k = k
self._in_feats = in_feats
......@@ -97,33 +95,38 @@ class ChebConv(layers.Layer):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
def unnLaplacian(feat, D_invsqrt, graph):
""" Operation Feat * D^-1/2 A D^-1/2 """
graph.ndata['h'] = feat * D_invsqrt
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
return graph.ndata.pop('h') * D_invsqrt
"""Operation Feat * D^-1/2 A D^-1/2"""
graph.ndata["h"] = feat * D_invsqrt
graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
return graph.ndata.pop("h") * D_invsqrt
with graph.local_scope():
in_degrees = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf)
in_degrees = tf.clip_by_value(
tf.cast(graph.in_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf,
)
D_invsqrt = tf.expand_dims(tf.pow(in_degrees, -0.5), axis=-1)
if lambda_max is None:
dgl_warning(
"lambda_max is not provided, using default value of 2. "
"Please use dgl.laplacian_lambda_max to compute the eigenvalues.")
"Please use dgl.laplacian_lambda_max to compute the eigenvalues."
)
lambda_max = [2] * graph.batch_size
if isinstance(lambda_max, list):
lambda_max = tf.constant(lambda_max, dtype=tf.float32)
if lambda_max.ndim == 1:
lambda_max = tf.expand_dims(
lambda_max, axis=-1) # (B,) to (B, 1)
lambda_max, axis=-1
) # (B,) to (B, 1)
# broadcast from (B, 1) to (N, 1)
lambda_max = broadcast_nodes(graph, lambda_max)
re_norm = 2. / lambda_max
re_norm = 2.0 / lambda_max
# X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
Xt = X_0 = feat
......@@ -131,14 +134,14 @@ class ChebConv(layers.Layer):
# X_1(f)
if self._k > 1:
h = unnLaplacian(X_0, D_invsqrt, graph)
X_1 = - re_norm * h + X_0 * (re_norm - 1)
X_1 = -re_norm * h + X_0 * (re_norm - 1)
# Concatenate Xt and X_1
Xt = tf.concat((Xt, X_1), 1)
# Xi(x), i = 2...k
for _ in range(2, self._k):
h = unnLaplacian(X_1, D_invsqrt, graph)
X_i = - 2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
# Concatenate Xt and X_i
Xt = tf.concat((Xt, X_i), 1)
X_1, X_0 = X_i, X_1
......
"""Tensorflow Module for DenseChebConv"""
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
class DenseChebConv(layers.Layer):
......@@ -30,11 +30,7 @@ class DenseChebConv(layers.Layer):
`ChebConv <https://docs.dgl.ai/api/python/nn.tensorflow.html#chebconv>`__
"""
def __init__(self,
in_feats,
out_feats,
k,
bias=True):
def __init__(self, in_feats, out_feats, k, bias=True):
super(DenseChebConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
......@@ -42,13 +38,19 @@ class DenseChebConv(layers.Layer):
# keras initializers assume the last two dims are fan_in and fan_out
xinit = tf.keras.initializers.glorot_normal()
self.W = tf.Variable(initial_value=xinit(
shape=(k, in_feats, out_feats), dtype='float32'), trainable=True)
self.W = tf.Variable(
initial_value=xinit(
shape=(k, in_feats, out_feats), dtype="float32"
),
trainable=True,
)
if bias:
zeroinit = tf.keras.initializers.zeros()
self.bias = tf.Variable(initial_value=zeroinit(
shape=(out_feats), dtype='float32'), trainable=True)
self.bias = tf.Variable(
initial_value=zeroinit(shape=(out_feats), dtype="float32"),
trainable=True,
)
else:
self.bias = None
......@@ -76,9 +78,11 @@ class DenseChebConv(layers.Layer):
"""
A = adj
num_nodes = A.shape[0]
in_degree = 1 / tf.sqrt(tf.clip_by_value(tf.reduce_sum(A, 1),
clip_value_min=1,
clip_value_max=np.inf))
in_degree = 1 / tf.sqrt(
tf.clip_by_value(
tf.reduce_sum(A, 1), clip_value_min=1, clip_value_max=np.inf
)
)
D_invsqrt = tf.linalg.diag(in_degree)
I = tf.eye(num_nodes)
L = I - D_invsqrt @ A @ D_invsqrt
......@@ -97,7 +101,7 @@ class DenseChebConv(layers.Layer):
Zs = tf.stack(Z, 0) # (k, n, n)
Zh = (Zs @ tf.expand_dims(feat, axis=0) @ self.W)
Zh = Zs @ tf.expand_dims(feat, axis=0) @ self.W
Zh = tf.reduce_sum(Zh, 0)
if self.bias is not None:
......
......@@ -60,10 +60,8 @@ class EdgeConv(layers.Layer):
A common practice to handle this is to filter out the nodes with zero in-degree when using
this module after conv.
"""
def __init__(self,
out_feats,
batch_norm=False,
allow_zero_in_degree=False):
def __init__(self, out_feats, batch_norm=False, allow_zero_in_degree=False):
super(EdgeConv, self).__init__()
self.batch_norm = batch_norm
self._allow_zero_in_degree = allow_zero_in_degree
......@@ -111,29 +109,31 @@ class EdgeConv(layers.Layer):
"""
with g.local_scope():
if not self._allow_zero_in_degree:
if tf.math.count_nonzero(g.in_degrees() == 0) > 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if tf.math.count_nonzero(g.in_degrees() == 0) > 0:
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
h_src, h_dst = expand_as_pair(feat, g)
g.srcdata['x'] = h_src
g.dstdata['x'] = h_dst
g.apply_edges(fn.v_sub_u('x', 'x', 'theta'))
g.edata['theta'] = self.theta(g.edata['theta'])
g.dstdata['phi'] = self.phi(g.dstdata['x'])
g.srcdata["x"] = h_src
g.dstdata["x"] = h_dst
g.apply_edges(fn.v_sub_u("x", "x", "theta"))
g.edata["theta"] = self.theta(g.edata["theta"])
g.dstdata["phi"] = self.phi(g.dstdata["x"])
if not self.batch_norm:
g.update_all(fn.e_add_v('theta', 'phi', 'e'), fn.max('e', 'x'))
g.update_all(fn.e_add_v("theta", "phi", "e"), fn.max("e", "x"))
else:
g.apply_edges(fn.e_add_v('theta', 'phi', 'e'))
g.apply_edges(fn.e_add_v("theta", "phi", "e"))
# for more comments on why global batch norm instead
# of batch norm within EdgeConv go to
# https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/edgeconv.py
g.edata['e'] = self.bn(g.edata['e'])
g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
return g.dstdata['x']
g.edata["e"] = self.bn(g.edata["e"])
g.update_all(fn.copy_e("e", "e"), fn.max("e", "x"))
return g.dstdata["x"]
......@@ -57,23 +57,26 @@ class GINConv(layers.Layer):
0.55207 , 1.2442873 , -0.17693758, 0.67841303, 0.8633929 ]],
dtype=float32)>
"""
def __init__(self,
apply_func,
aggregator_type,
init_eps=0,
learn_eps=False):
def __init__(
self, apply_func, aggregator_type, init_eps=0, learn_eps=False
):
super(GINConv, self).__init__()
self.apply_func = apply_func
if aggregator_type == 'sum':
if aggregator_type == "sum":
self._reducer = fn.sum
elif aggregator_type == 'max':
elif aggregator_type == "max":
self._reducer = fn.max
elif aggregator_type == 'mean':
elif aggregator_type == "mean":
self._reducer = fn.mean
else:
raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
raise KeyError(
"Aggregator type {} not recognized.".format(aggregator_type)
)
# to specify whether eps is trainable or not.
self.eps = tf.Variable(initial_value=[init_eps], dtype=tf.float32, trainable=learn_eps)
self.eps = tf.Variable(
initial_value=[init_eps], dtype=tf.float32, trainable=learn_eps
)
def call(self, graph, feat):
r"""Compute Graph Isomorphism Network layer.
......@@ -100,9 +103,9 @@ class GINConv(layers.Layer):
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
graph.srcdata["h"] = feat_src
graph.update_all(fn.copy_u("h", "m"), self._reducer("m", "neigh"))
rst = (1 + self.eps) * feat_dst + graph.dstdata["neigh"]
if self.apply_func is not None:
rst = self.apply_func(rst)
return rst
......@@ -108,60 +108,84 @@ class RelGraphConv(layers.Layer):
[-0.14293689, 0.77483284],
[ 0.091169 , -0.06761569]], dtype=float32)>
"""
def __init__(self,
in_feat,
out_feat,
num_rels,
regularizer="basis",
num_bases=None,
bias=True,
activation=None,
self_loop=True,
low_mem=False,
dropout=0.0,
layer_norm=False):
def __init__(
self,
in_feat,
out_feat,
num_rels,
regularizer="basis",
num_bases=None,
bias=True,
activation=None,
self_loop=True,
low_mem=False,
dropout=0.0,
layer_norm=False,
):
super(RelGraphConv, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.num_rels = num_rels
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0:
if (
self.num_bases is None
or self.num_bases > self.num_rels
or self.num_bases < 0
):
self.num_bases = self.num_rels
self.bias = bias
self.activation = activation
self.self_loop = self_loop
self.low_mem = low_mem
assert layer_norm is False, 'TensorFlow currently does not support layer norm.'
assert (
layer_norm is False
), "TensorFlow currently does not support layer norm."
xinit = tf.keras.initializers.glorot_uniform()
zeroinit = tf.keras.initializers.zeros()
if regularizer == "basis":
# add basis weights
self.weight = tf.Variable(initial_value=xinit(
shape=(self.num_bases, self.in_feat, self.out_feat),
dtype='float32'), trainable=True)
self.weight = tf.Variable(
initial_value=xinit(
shape=(self.num_bases, self.in_feat, self.out_feat),
dtype="float32",
),
trainable=True,
)
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = tf.Variable(initial_value=xinit(
shape=(self.num_rels, self.num_bases), dtype='float32'), trainable=True)
self.w_comp = tf.Variable(
initial_value=xinit(
shape=(self.num_rels, self.num_bases), dtype="float32"
),
trainable=True,
)
# message func
self.message_func = self.basis_message_func
elif regularizer == "bdd":
if in_feat % num_bases != 0 or out_feat % num_bases != 0:
raise ValueError(
'Feature size must be a multiplier of num_bases.')
"Feature size must be a multiplier of num_bases."
)
# add block diagonal weights
self.submat_in = in_feat // self.num_bases
self.submat_out = out_feat // self.num_bases
# assuming in_feat and out_feat are both divisible by num_bases
self.weight = tf.Variable(initial_value=xinit(
shape=(self.num_rels, self.num_bases *
self.submat_in * self.submat_out),
dtype='float32'), trainable=True)
self.weight = tf.Variable(
initial_value=xinit(
shape=(
self.num_rels,
self.num_bases * self.submat_in * self.submat_out,
),
dtype="float32",
),
trainable=True,
)
# message func
self.message_func = self.bdd_message_func
else:
......@@ -169,13 +193,17 @@ class RelGraphConv(layers.Layer):
# bias
if self.bias:
self.h_bias = tf.Variable(initial_value=zeroinit(
shape=(out_feat), dtype='float32'), trainable=True)
self.h_bias = tf.Variable(
initial_value=zeroinit(shape=(out_feat), dtype="float32"),
trainable=True,
)
# weight for self loop
if self.self_loop:
self.loop_weight = tf.Variable(initial_value=xinit(
shape=(in_feat, out_feat), dtype='float32'), trainable=True)
self.loop_weight = tf.Variable(
initial_value=xinit(shape=(in_feat, out_feat), dtype="float32"),
trainable=True,
)
self.dropout = layers.Dropout(rate=dropout)
......@@ -183,64 +211,76 @@ class RelGraphConv(layers.Layer):
"""Message function for basis regularizer"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = tf.reshape(self.weight, (self.num_bases,
self.in_feat * self.out_feat))
weight = tf.reshape(tf.matmul(self.w_comp, weight), (
self.num_rels, self.in_feat, self.out_feat))
weight = tf.reshape(
self.weight, (self.num_bases, self.in_feat * self.out_feat)
)
weight = tf.reshape(
tf.matmul(self.w_comp, weight),
(self.num_rels, self.in_feat, self.out_feat),
)
else:
weight = self.weight
# calculate msg @ W_r before put msg into edge
# if src is th.int64 we expect it is an index select
if edges.src['h'].dtype != tf.int64 and self.low_mem:
etypes, _ = tf.unique(edges.data['type'])
msg = tf.zeros([edges.src['h'].shape[0], self.out_feat])
idx = tf.range(edges.src['h'].shape[0])
if edges.src["h"].dtype != tf.int64 and self.low_mem:
etypes, _ = tf.unique(edges.data["type"])
msg = tf.zeros([edges.src["h"].shape[0], self.out_feat])
idx = tf.range(edges.src["h"].shape[0])
for etype in etypes:
loc = (edges.data['type'] == etype)
loc = edges.data["type"] == etype
w = weight[etype]
src = tf.boolean_mask(edges.src['h'], loc)
src = tf.boolean_mask(edges.src["h"], loc)
sub_msg = tf.matmul(src, w)
indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
else:
msg = utils.bmm_maybe_select(
edges.src['h'], weight, edges.data['type'])
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
edges.src["h"], weight, edges.data["type"]
)
if "norm" in edges.data:
msg = msg * edges.data["norm"]
return {"msg": msg}
def bdd_message_func(self, edges):
"""Message function for block-diagonal-decomposition regularizer"""
if ((edges.src['h'].dtype == tf.int64) and
len(edges.src['h'].shape) == 1):
if (edges.src["h"].dtype == tf.int64) and len(
edges.src["h"].shape
) == 1:
raise TypeError(
'Block decomposition does not allow integer ID feature.')
"Block decomposition does not allow integer ID feature."
)
# calculate msg @ W_r before put msg into edge
# if src is th.int64 we expect it is an index select
if self.low_mem:
etypes, _ = tf.unique(edges.data['type'])
msg = tf.zeros([edges.src['h'].shape[0], self.out_feat])
idx = tf.range(edges.src['h'].shape[0])
etypes, _ = tf.unique(edges.data["type"])
msg = tf.zeros([edges.src["h"].shape[0], self.out_feat])
idx = tf.range(edges.src["h"].shape[0])
for etype in etypes:
loc = (edges.data['type'] == etype)
w = tf.reshape(self.weight[etype],
(self.num_bases, self.submat_in, self.submat_out))
src = tf.reshape(tf.boolean_mask(edges.src['h'], loc),
(-1, self.num_bases, self.submat_in))
sub_msg = tf.einsum('abc,bcd->abd', src, w)
loc = edges.data["type"] == etype
w = tf.reshape(
self.weight[etype],
(self.num_bases, self.submat_in, self.submat_out),
)
src = tf.reshape(
tf.boolean_mask(edges.src["h"], loc),
(-1, self.num_bases, self.submat_in),
)
sub_msg = tf.einsum("abc,bcd->abd", src, w)
sub_msg = tf.reshape(sub_msg, (-1, self.out_feat))
indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
else:
weight = tf.reshape(tf.gather(
self.weight, edges.data['type']), (-1, self.submat_in, self.submat_out))
node = tf.reshape(edges.src['h'], (-1, 1, self.submat_in))
weight = tf.reshape(
tf.gather(self.weight, edges.data["type"]),
(-1, self.submat_in, self.submat_out),
)
node = tf.reshape(edges.src["h"], (-1, 1, self.submat_in))
msg = tf.reshape(tf.matmul(node, weight), (-1, self.out_feat))
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
if "norm" in edges.data:
msg = msg * edges.data["norm"]
return {"msg": msg}
def call(self, g, x, etypes, norm=None):
"""Forward computation
......@@ -265,20 +305,21 @@ class RelGraphConv(layers.Layer):
tf.Tensor
New node features.
"""
assert g.is_homogeneous, \
"not a homogeneous graph; convert it with to_homogeneous " \
assert g.is_homogeneous, (
"not a homogeneous graph; convert it with to_homogeneous "
"and pass in the edge type as argument"
)
with g.local_scope():
g.ndata['h'] = x
g.edata['type'] = tf.cast(etypes, tf.int64)
g.ndata["h"] = x
g.edata["type"] = tf.cast(etypes, tf.int64)
if norm is not None:
g.edata['norm'] = norm
g.edata["norm"] = norm
if self.self_loop:
loop_message = utils.matmul_maybe_select(x, self.loop_weight)
# message passing
g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
g.update_all(self.message_func, fn.sum(msg="msg", out="h"))
# apply bias and activation
node_repr = g.ndata['h']
node_repr = g.ndata["h"]
if self.bias:
node_repr = node_repr + self.h_bias
if self.self_loop:
......
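For reference, the basis regularizer in the RelGraphConv hunk above generates one weight matrix per relation as a linear combination of the shared bases before message passing; a short shape sketch under assumed sizes (not tied to any backend module):
>>> import torch
>>> num_rels, num_bases, in_feat, out_feat = 5, 2, 4, 3
>>> bases = torch.randn(num_bases, in_feat, out_feat)
>>> w_comp = torch.randn(num_rels, num_bases)
>>> # W_r = sum_b w_comp[r, b] * bases[b], reshaped back to (num_rels, in_feat, out_feat)
>>> weight = torch.matmul(w_comp, bases.view(num_bases, -1)).view(num_rels, in_feat, out_feat)
>>> weight.shape
torch.Size([5, 4, 3])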
"""tf Module for Simplifying Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name, W0613
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from .... import function as fn
from ....base import DGLError
......@@ -83,14 +83,17 @@ class SGConv(layers.Layer):
[0.60570633, 0.520766 ],
[0.6102368 , 0.52466124]], dtype=float32)>
"""
def __init__(self,
in_feats,
out_feats,
k=1,
cached=False,
bias=True,
norm=None,
allow_zero_in_degree=False):
def __init__(
self,
in_feats,
out_feats,
k=1,
cached=False,
bias=True,
norm=None,
allow_zero_in_degree=False,
):
super(SGConv, self).__init__()
self.fc = layers.Dense(out_feats, use_bias=bias)
self._cached = cached
......@@ -140,32 +143,36 @@ class SGConv(layers.Layer):
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if self._cached_h is not None:
feat = self._cached_h
else:
# compute normalization
degs = tf.clip_by_value(tf.cast(
graph.in_degrees(), tf.float32), clip_value_min=1, clip_value_max=np.inf)
degs = tf.clip_by_value(
tf.cast(graph.in_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf,
)
norm = tf.pow(degs, -0.5)
norm = tf.expand_dims(norm, 1)
# compute (D^{-1/2} A D^{-1/2})^k X
for _ in range(self._k):
feat = feat * norm
graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'),
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
graph.ndata["h"] = feat
graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
feat = graph.ndata.pop("h")
feat = feat * norm
if self.norm is not None:
......
......@@ -3,13 +3,22 @@
import tensorflow as tf
from tensorflow.keras import layers
from ...readout import sum_nodes, mean_nodes, max_nodes, \
softmax_nodes, topk_nodes
__all__ = ['SumPooling', 'AvgPooling',
'MaxPooling', 'SortPooling', 'WeightAndSum', 'GlobalAttentionPooling']
from ...readout import (
max_nodes,
mean_nodes,
softmax_nodes,
sum_nodes,
topk_nodes,
)
__all__ = [
"SumPooling",
"AvgPooling",
"MaxPooling",
"SortPooling",
"WeightAndSum",
"GlobalAttentionPooling",
]
class SumPooling(layers.Layer):
......@@ -41,8 +50,8 @@ class SumPooling(layers.Layer):
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = sum_nodes(graph, 'h')
graph.ndata["h"] = feat
readout = sum_nodes(graph, "h")
return readout
......@@ -74,8 +83,8 @@ class AvgPooling(layers.Layer):
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = mean_nodes(graph, 'h')
graph.ndata["h"] = feat
readout = mean_nodes(graph, "h")
return readout
......@@ -107,8 +116,8 @@ class MaxPooling(layers.Layer):
:math:`B` refers to the batch size.
"""
with graph.local_scope():
graph.ndata['h'] = feat
readout = max_nodes(graph, 'h')
graph.ndata["h"] = feat
readout = max_nodes(graph, "h")
return readout
......@@ -146,10 +155,12 @@ class SortPooling(layers.Layer):
with graph.local_scope():
# Sort the feature of each node in ascending order.
feat = tf.sort(feat, -1)
graph.ndata['h'] = feat
graph.ndata["h"] = feat
# Sort nodes according to their last features.
ret = tf.reshape(topk_nodes(graph, 'h', self.k, sortby=-1)[0], (
-1, self.k * feat.shape[-1]))
ret = tf.reshape(
topk_nodes(graph, "h", self.k, sortby=-1)[0],
(-1, self.k * feat.shape[-1]),
)
return ret
......@@ -194,16 +205,18 @@ class GlobalAttentionPooling(layers.Layer):
"""
with graph.local_scope():
gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
assert (
gate.shape[-1] == 1
), "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat
graph.ndata['gate'] = gate
gate = softmax_nodes(graph, 'gate')
graph.ndata.pop('gate')
graph.ndata["gate"] = gate
gate = softmax_nodes(graph, "gate")
graph.ndata.pop("gate")
graph.ndata['r'] = feat * gate
readout = sum_nodes(graph, 'r')
graph.ndata.pop('r')
graph.ndata["r"] = feat * gate
readout = sum_nodes(graph, "r")
graph.ndata.pop("r")
return readout
......@@ -221,8 +234,7 @@ class WeightAndSum(layers.Layer):
super(WeightAndSum, self).__init__()
self.in_feats = in_feats
self.atom_weighting = tf.keras.Sequential(
layers.Dense(1),
layers.Activation(tf.nn.sigmoid)
layers.Dense(1), layers.Activation(tf.nn.sigmoid)
)
def call(self, g, feats):
......@@ -242,8 +254,8 @@ class WeightAndSum(layers.Layer):
Representations for B molecules
"""
with g.local_scope():
g.ndata['h'] = feats
g.ndata['w'] = self.atom_weighting(g.ndata['h'])
h_g_sum = sum_nodes(g, 'h', 'w')
g.ndata["h"] = feats
g.ndata["w"] = self.atom_weighting(g.ndata["h"])
h_g_sum = sum_nodes(g, "h", "w")
return h_g_sum
......@@ -2,7 +2,8 @@
import tensorflow as tf
from tensorflow.keras import layers
__all__ = ['HeteroGraphConv']
__all__ = ["HeteroGraphConv"]
class HeteroGraphConv(layers.Layer):
r"""A generic module for computing convolution on heterogeneous graphs.
......@@ -125,13 +126,16 @@ class HeteroGraphConv(layers.Layer):
mods : dict[str, nn.Module]
Modules associated with every edge types.
"""
def __init__(self, mods, aggregate='sum'):
def __init__(self, mods, aggregate="sum"):
super(HeteroGraphConv, self).__init__()
self.mods = mods
# Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items():
set_allow_zero_in_degree_fn = getattr(v, 'set_allow_zero_in_degree', None)
set_allow_zero_in_degree_fn = getattr(
v, "set_allow_zero_in_degree", None
)
if callable(set_allow_zero_in_degree_fn):
set_allow_zero_in_degree_fn(True)
if isinstance(aggregate, str):
......@@ -164,7 +168,7 @@ class HeteroGraphConv(layers.Layer):
mod_args = {}
if mod_kwargs is None:
mod_kwargs = {}
outputs = {nty : [] for nty in g.dsttypes}
outputs = {nty: [] for nty in g.dsttypes}
if isinstance(inputs, tuple):
src_inputs, dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes:
......@@ -175,7 +179,8 @@ class HeteroGraphConv(layers.Layer):
rel_graph,
(src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
**mod_kwargs.get(etype, {})
)
outputs[dtype].append(dstdata)
else:
for stype, etype, dtype in g.canonical_etypes:
......@@ -186,7 +191,8 @@ class HeteroGraphConv(layers.Layer):
rel_graph,
(inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
**mod_kwargs.get(etype, {})
)
outputs[dtype].append(dstdata)
rsts = {}
for nty, alist in outputs.items():
......@@ -194,6 +200,7 @@ class HeteroGraphConv(layers.Layer):
rsts[nty] = self.agg_fn(alist, nty)
return rsts
def get_aggregate_fn(agg):
"""Internal function to get the aggregation function for node data
generated from different relations.
......@@ -210,29 +217,35 @@ def get_aggregate_fn(agg):
Aggregator function that takes a list of tensors to aggregate
and returns one aggregated tensor.
"""
if agg == 'sum':
if agg == "sum":
fn = tf.reduce_sum
elif agg == 'max':
elif agg == "max":
fn = tf.reduce_max
elif agg == 'min':
elif agg == "min":
fn = tf.reduce_min
elif agg == 'mean':
elif agg == "mean":
fn = tf.reduce_mean
elif agg == 'stack':
elif agg == "stack":
fn = None # will not be called
else:
raise DGLError('Invalid cross type aggregator. Must be one of '
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg)
if agg == 'stack':
raise DGLError(
"Invalid cross type aggregator. Must be one of "
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg
)
if agg == "stack":
def stack_agg(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
return tf.stack(inputs, axis=1)
return stack_agg
else:
def aggfn(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0:
return None
stacked = tf.stack(inputs, axis=0)
return fn(stacked, axis=0)
return aggfn
"""Utilities for tf NN package"""
# pylint: disable=no-member, invalid-name
from tensorflow.keras import layers # pylint: disable=W0235
import tensorflow as tf
from tensorflow.keras import layers # pylint: disable=W0235
def matmul_maybe_select(A, B):
......@@ -91,8 +91,7 @@ def bmm_maybe_select(A, B, index):
class Identity(layers.Layer):
"""A placeholder identity operator that is argument-insensitive.
"""
"""A placeholder identity operator that is argument-insensitive."""
def call(self, x):
"""Return input"""
......
"""dgl operator module."""
from .spmm import *
from .sddmm import *
from .edge_softmax import *
from .segment import *
from .gather_mm import *
from .sddmm import *
from .segment import *
from .spmm import *
"""dgl edge_softmax operator module."""
from ..backend import astype
from ..backend import edge_softmax as edge_softmax_internal
from ..backend import edge_softmax_hetero as edge_softmax_hetero_internal
from ..backend import astype
from ..base import ALL, is_all
__all__ = ['edge_softmax']
__all__ = ["edge_softmax"]
def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
def edge_softmax(graph, logits, eids=ALL, norm_by="dst"):
r"""Compute softmax over weights of incoming edges for every node.
For a node :math:`i`, edge softmax is an operation that computes
......@@ -131,8 +131,9 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
if not is_all(eids):
eids = astype(eids, graph.idtype)
if graph._graph.number_of_etypes() == 1:
return edge_softmax_internal(graph._graph, logits,
eids=eids, norm_by=norm_by)
return edge_softmax_internal(
graph._graph, logits, eids=eids, norm_by=norm_by
)
else:
logits_list = [None] * graph._graph.number_of_etypes()
logits = {graph.to_canonical_etype(k): v for k, v in logits.items()}
......@@ -140,8 +141,9 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
etid = graph.get_etype_id(rel)
logits_list[etid] = logits[rel]
logits_tuple = tuple(logits_list)
score_tuple = edge_softmax_hetero_internal(graph._graph,
eids, norm_by, *logits_tuple)
score_tuple = edge_softmax_hetero_internal(
graph._graph, eids, norm_by, *logits_tuple
)
score = {}
for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel)
......
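A minimal sketch of edge_softmax on a homogeneous graph; the dgl.ops import follows the package __init__ shown above, everything else is illustrative:
>>> import dgl
>>> import torch
>>> from dgl.ops import edge_softmax
>>> g = dgl.graph(([0, 0, 1, 2], [1, 2, 2, 3]))
>>> logits = torch.randn(g.num_edges(), 1)
>>> scores = edge_softmax(g, logits)  # normalized over the incoming edges of each destination node
>>> scores.shape
torch.Size([4, 1])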
"""dgl gather_mm operator module."""
from .. import backend as F
__all__ = ['gather_mm']
__all__ = ["gather_mm"]
def gather_mm(a, b, *, idx_b):
r"""Gather data according to the given indices and perform matrix multiplication.
......@@ -31,12 +32,19 @@ def gather_mm(a, b, *, idx_b):
if N > 1000000 or D1 > 8 or D2 > 8:
# Use segment_mm for large workload
import torch
sorted_idx_b, perm = torch.sort(idx_b)
_, rev_perm = torch.sort(perm)
sorted_a = torch.index_select(a, 0, perm)
pos_l = torch.searchsorted(sorted_idx_b, torch.arange(R, device=a.device))
pos_r = torch.cat([pos_l[1:], torch.tensor([len(idx_b)], device=a.device)])
pos_l = torch.searchsorted(
sorted_idx_b, torch.arange(R, device=a.device)
)
pos_r = torch.cat(
[pos_l[1:], torch.tensor([len(idx_b)], device=a.device)]
)
seglen = (pos_r - pos_l).cpu() # XXX(minjie): cause device synchronize
return torch.index_select(F.segment_mm(sorted_a, b, seglen), 0, rev_perm)
return torch.index_select(
F.segment_mm(sorted_a, b, seglen), 0, rev_perm
)
else:
return F.gather_mm(a, b, None, idx_b)
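A minimal sketch of gather_mm, which selects one weight matrix per row of `a` via `idx_b` and multiplies; sizes are illustrative, and depending on the DGL build the small-workload path may require CUDA tensors:
>>> import torch
>>> from dgl.ops import gather_mm
>>> a = torch.randn(6, 4)
>>> b = torch.randn(3, 4, 5)                    # one (4, 5) weight matrix per relation
>>> idx_b = torch.tensor([0, 1, 2, 0, 1, 2])
>>> gather_mm(a, b, idx_b=idx_b).shape          # row i is a[i] @ b[idx_b[i]]
torch.Size([6, 5])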