Unverified Commit 76bb5404 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4682)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a208e886
@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F


class EdgePredictor(nn.Module):
    r"""Predictor/score function for pairs of node representations

@@ -102,20 +103,21 @@ class EdgePredictor(nn.Module):
    >>> predictor(h_src, h_dst).shape
    torch.Size([3, 3])
    """

    def __init__(self, op, in_feats=None, out_feats=None, bias=False):
        super(EdgePredictor, self).__init__()
        assert op in [
            "dot",
            "cos",
            "ele",
            "cat",
        ], "Expect op to be in ['dot', 'cos', 'ele', 'cat'], got {}".format(op)
        self.op = op
        if (in_feats is not None) and (out_feats is not None):
            if op in ["dot", "cos"]:
                in_feats = 1
            elif op == "cat":
                in_feats = 2 * in_feats
            self.linear = nn.Linear(in_feats, out_feats, bias=bias)
        else:

@@ -154,12 +156,12 @@ class EdgePredictor(nn.Module):
        torch.Tensor
            The output features.
        """
        if self.op == "dot":
            N, D = h_src.shape
            h = torch.bmm(h_src.view(N, 1, D), h_dst.view(N, D, 1)).squeeze(-1)
        elif self.op == "cos":
            h = F.cosine_similarity(h_src, h_dst).unsqueeze(-1)
        elif self.op == "ele":
            h = h_src * h_dst
        else:
            h = torch.cat([h_src, h_dst], dim=-1)
...
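# Illustrative usage sketch for the scoring ops reformatted above; it assumes the
# class is exposed as dgl.nn.EdgePredictor with the __init__(op, in_feats, out_feats,
# bias) signature shown in this hunk.
import torch
from dgl.nn import EdgePredictor

h_src = torch.randn(3, 8)  # representations of the source endpoints
h_dst = torch.randn(3, 8)  # representations of the destination endpoints

# "dot" scores each pair with an inner product -> shape (3, 1)
print(EdgePredictor("dot")(h_src, h_dst).shape)
# "cat" concatenates each pair and applies the learned linear layer -> shape (3, 4)
print(EdgePredictor("cat", in_feats=8, out_feats=4)(h_src, h_dst).shape)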
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn


class TransE(nn.Module):
    r"""Similarity measure from `Translating Embeddings for Modeling Multi-relational Data
    <https://papers.nips.cc/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html>`__

@@ -53,6 +54,7 @@ class TransE(nn.Module):
    >>> scorer(h_head, h_tail, rels).shape
    torch.Size([30])
    """

    def __init__(self, num_rels, feats, p=1):
        super(TransE, self).__init__()

@@ -94,4 +96,4 @@ class TransE(nn.Module):
        """
        h_rel = self.rel_emb(rels)
        return -torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)
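# Worked sketch of the TransE score returned above (score = -||h + r - t||_p),
# written with plain tensors so it runs without the surrounding module.
import torch

h_head = torch.randn(30, 16)
h_tail = torch.randn(30, 16)
h_rel = torch.randn(30, 16)  # in the module this comes from self.rel_emb(rels)
score = -torch.norm(h_head + h_rel - h_tail, p=1, dim=-1)  # p=1 is the default above
print(score.shape)  # torch.Size([30]), matching the docstring example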
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn


class TransR(nn.Module):
    r"""Similarity measure from
    `Learning entity and relation embeddings for knowledge graph completion

@@ -58,6 +59,7 @@ class TransR(nn.Module):
    >>> scorer(h_head, h_tail, rels).shape
    torch.Size([30])
    """

    def __init__(self, num_rels, rfeats, nfeats, p=1):
        super(TransR, self).__init__()

@@ -103,4 +105,4 @@ class TransR(nn.Module):
        h_head = (h_head.unsqueeze(1) @ proj_rel).squeeze(1)
        h_tail = (h_tail.unsqueeze(1) @ proj_rel).squeeze(1)
        return -torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)
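# Plain-tensor sketch of the TransR projection + translation scoring in the hunk
# above; proj_rel stands in for the relation-specific projections the module learns.
import torch

B, nfeats, rfeats = 30, 16, 8
h_head = torch.randn(B, nfeats)
h_tail = torch.randn(B, nfeats)
h_rel = torch.randn(B, rfeats)
proj_rel = torch.randn(B, nfeats, rfeats)  # one (nfeats, rfeats) projection per triple

h_head = (h_head.unsqueeze(1) @ proj_rel).squeeze(1)  # (B, rfeats)
h_tail = (h_tail.unsqueeze(1) @ proj_rel).squeeze(1)  # (B, rfeats)
print((-torch.norm(h_head + h_rel - h_tail, p=1, dim=-1)).shape)  # torch.Size([30])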
"""Torch NodeEmbedding.""" """Torch NodeEmbedding."""
from datetime import timedelta from datetime import timedelta
import torch as th import torch as th
from ...backend import pytorch as F from ...backend import pytorch as F
from ...utils import get_shared_mem_array, create_shared_mem_array
from ...cuda import nccl from ...cuda import nccl
from ...partition import NDArrayPartition from ...partition import NDArrayPartition
from ...utils import create_shared_mem_array, get_shared_mem_array
_STORE = None _STORE = None
_COMM = None _COMM = None
class NodeEmbedding: # NodeEmbedding
'''Class for storing node embeddings. class NodeEmbedding: # NodeEmbedding
"""Class for storing node embeddings.
The class is optimized for training large-scale node embeddings. It updates the embedding in The class is optimized for training large-scale node embeddings. It updates the embedding in
a sparse way and can scale to graphs with millions of nodes. It also supports partitioning a sparse way and can scale to graphs with millions of nodes. It also supports partitioning
...@@ -63,15 +66,22 @@ class NodeEmbedding: # NodeEmbedding ...@@ -63,15 +66,22 @@ class NodeEmbedding: # NodeEmbedding
... loss = F.sum(feats + 1, 0) ... loss = F.sum(feats + 1, 0)
... loss.backward() ... loss.backward()
... optimizer.step() ... optimizer.step()
''' """
def __init__(self, num_embeddings, embedding_dim, name, def __init__(
init_func=None, device=None, partition=None): self,
num_embeddings,
embedding_dim,
name,
init_func=None,
device=None,
partition=None,
):
global _STORE global _STORE
global _COMM global _COMM
if device is None: if device is None:
device = th.device('cpu') device = th.device("cpu")
# Check whether it is multi-gpu training or not. # Check whether it is multi-gpu training or not.
if th.distributed.is_initialized(): if th.distributed.is_initialized():
...@@ -86,7 +96,7 @@ class NodeEmbedding: # NodeEmbedding ...@@ -86,7 +96,7 @@ class NodeEmbedding: # NodeEmbedding
self._comm = None self._comm = None
self._partition = partition self._partition = partition
host_name = '127.0.0.1' host_name = "127.0.0.1"
port = 12346 port = 12346
if rank >= 0: if rank >= 0:
@@ -94,25 +104,34 @@ class NodeEmbedding:  # NodeEmbedding
            # embedding status synchronization across GPU processes
            if _STORE is None:
                _STORE = th.distributed.TCPStore(
                    host_name,
                    port,
                    world_size,
                    rank == 0,
                    timedelta(seconds=10 * 60),
                )
            self._store = _STORE

        # embeddings is stored in CPU memory.
        if th.device(device) == th.device("cpu"):
            if rank <= 0:
                emb = create_shared_mem_array(
                    name, (num_embeddings, embedding_dim), th.float32
                )
                if init_func is not None:
                    emb = init_func(emb)
            if rank == 0:  # the master gpu process
                for _ in range(1, world_size):
                    # send embs
                    self._store.set(name, name)
            elif rank > 0:
                # receive
                self._store.wait([name])
                emb = get_shared_mem_array(
                    name, (num_embeddings, embedding_dim), th.float32
                )
            self._tensor = emb
        else:  # embeddings is stored in GPU memory.
            # setup nccl communicator
            if _COMM is None:
                if rank < 0:
@@ -123,11 +142,14 @@ class NodeEmbedding:  # NodeEmbedding
                if rank == 0:
                    # root process broadcasts nccl id
                    nccl_id = nccl.UniqueId()
                    self._store.set("nccl_root_id_sparse_emb", str(nccl_id))
                else:
                    nccl_id = nccl.UniqueId(
                        self._store.get("nccl_root_id_sparse_emb")
                    )
                _COMM = nccl.Communicator(
                    self._world_size, self._rank, nccl_id
                )
            self._comm = _COMM

            if not self._partition:

@@ -135,14 +157,19 @@ class NodeEmbedding:  # NodeEmbedding
                self._partition = NDArrayPartition(
                    num_embeddings,
                    self._world_size if self._world_size > 0 else 1,
                    mode="remainder",
                )

            # create local tensors for the weights
            local_size = self._partition.local_size(self._comm.rank())

            # TODO(dlasalle): support 16-bit/half embeddings
            emb = th.empty(
                [local_size, embedding_dim],
                dtype=th.float32,
                requires_grad=False,
                device=device,
            )
            if init_func:
                emb = init_func(emb)
            self._tensor = emb
@@ -150,10 +177,10 @@ class NodeEmbedding:  # NodeEmbedding
        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._name = name
        self._optm_state = None  # track optimizer state
        self._trace = []  # track minibatch

    def __call__(self, node_ids, device=th.device("cpu")):
        """
        node_ids : th.tensor
            Index of the embeddings to collect.

@@ -165,7 +192,8 @@ class NodeEmbedding:  # NodeEmbedding
        else:
            if self.world_size > 0:
                emb = self._comm.sparse_all_to_all_pull(
                    node_ids, self._tensor, self._partition
                )
            else:
                emb = self._tensor[node_ids]
            emb = emb.to(device)
@@ -331,7 +359,7 @@ class NodeEmbedding:  # NodeEmbedding
        return self._tensor

    def all_set_embedding(self, values):
        """Set the values of the embedding. This method must be called by all
        processes sharing the embedding with identical tensors for
        :attr:`values`.

@@ -346,20 +374,23 @@ class NodeEmbedding:  # NodeEmbedding
        if self._partition:
            idxs = F.copy_to(
                self._partition.get_local_indices(
                    self._comm.rank(), ctx=F.context(self._tensor)
                ),
                F.context(values),
            )
            self._tensor[:] = F.copy_to(
                F.gather_row(values, idxs), ctx=F.context(self._tensor)
            )[:]
        else:
            if self._rank == 0:
                self._tensor[:] = F.copy_to(
                    values, ctx=F.context(self._tensor)
                )[:]
        if th.distributed.is_initialized():
            th.distributed.barrier()

    def all_get_embedding(self):
        """Return a copy of the embedding stored in CPU memory. If this is a
        multi-processing instance, the tensor will be returned in shared
        memory. If the embedding is currently stored on multiple GPUs, all
        processes must call this method in the same order.
@@ -375,7 +406,7 @@ class NodeEmbedding:  # NodeEmbedding
        if self._partition:
            if self._world_size == 0:
                # non-multiprocessing
                return self._tensor.to(th.device("cpu"))
            else:
                # create a shared memory tensor
                shared_name = self._name + "_gather"

@@ -384,18 +415,23 @@ class NodeEmbedding:  # NodeEmbedding
                    emb = create_shared_mem_array(
                        shared_name,
                        (self._num_embeddings, self._embedding_dim),
                        self._tensor.dtype,
                    )
                    self._store.set(shared_name, shared_name)
                else:
                    self._store.wait([shared_name])
                    emb = get_shared_mem_array(
                        shared_name,
                        (self._num_embeddings, self._embedding_dim),
                        self._tensor.dtype,
                    )

                # need to map indices and slice into existing tensor
                idxs = self._partition.map_to_global(
                    F.arange(
                        0, self._tensor.shape[0], ctx=F.context(self._tensor)
                    ),
                    self._rank,
                ).to(emb.device)
                emb[idxs] = self._tensor.to(emb.device)

                # wait for all processes to finish
...
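# Hedged single-process sketch of the training loop the NodeEmbedding docstring above
# alludes to; it assumes dgl.nn.NodeEmbedding and dgl.optim.SparseAdam are the public
# entry points (those names are not shown in this hunk).
import torch as th
import dgl

def initializer(emb):
    th.nn.init.xavier_uniform_(emb)
    return emb

emb = dgl.nn.NodeEmbedding(100, 16, "demo_emb", init_func=initializer)
optimizer = dgl.optim.SparseAdam([emb], lr=0.01)

feats = emb(th.arange(10))  # gather rows; the access is traced for the sparse update
loss = feats.sum()
loss.backward()
optimizer.step()            # updates only the embedding rows touched above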
"""Utilities for pytorch NN package""" """Utilities for pytorch NN package"""
#pylint: disable=no-member, invalid-name # pylint: disable=no-member, invalid-name
import torch as th import torch as th
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from ... import DGLGraph from ... import DGLGraph
from ...base import dgl_warning
from ... import function as fn from ... import function as fn
from ...base import dgl_warning
def matmul_maybe_select(A, B): def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector. """Perform Matrix multiplication C = A * B but A could be an integer id vector.
...@@ -49,6 +51,7 @@ def matmul_maybe_select(A, B): ...@@ -49,6 +51,7 @@ def matmul_maybe_select(A, B):
else: else:
return th.matmul(A, B) return th.matmul(A, B)
def bmm_maybe_select(A, B, index): def bmm_maybe_select(A, B, index):
"""Slice submatrices of A by the given index and perform bmm. """Slice submatrices of A by the given index and perform bmm.
...@@ -92,12 +95,14 @@ def bmm_maybe_select(A, B, index): ...@@ -92,12 +95,14 @@ def bmm_maybe_select(A, B, index):
BB = B.index_select(0, index) BB = B.index_select(0, index)
return th.bmm(A.unsqueeze(1), BB).squeeze() return th.bmm(A.unsqueeze(1), BB).squeeze()
# pylint: disable=W0235 # pylint: disable=W0235
class Identity(nn.Module): class Identity(nn.Module):
"""A placeholder identity operator that is argument-insensitive. """A placeholder identity operator that is argument-insensitive.
(Identity has already been supported by PyTorch 1.2, we will directly (Identity has already been supported by PyTorch 1.2, we will directly
import torch.nn.Identity in the future) import torch.nn.Identity in the future)
""" """
def __init__(self): def __init__(self):
super(Identity, self).__init__() super(Identity, self).__init__()
...@@ -105,6 +110,7 @@ class Identity(nn.Module): ...@@ -105,6 +110,7 @@ class Identity(nn.Module):
"""Return input""" """Return input"""
return x return x
class Sequential(nn.Sequential): class Sequential(nn.Sequential):
r"""A sequential container for stacking graph neural network modules r"""A sequential container for stacking graph neural network modules
...@@ -220,10 +226,13 @@ class Sequential(nn.Sequential): ...@@ -220,10 +226,13 @@ class Sequential(nn.Sequential):
feats = (feats,) feats = (feats,)
feats = module(graph, *feats) feats = module(graph, *feats)
else: else:
raise TypeError('The first argument of forward must be a DGLGraph' raise TypeError(
' or a list of DGLGraph s') "The first argument of forward must be a DGLGraph"
" or a list of DGLGraph s"
)
return feats return feats
class WeightBasis(nn.Module):
    r"""Basis decomposition from `Modeling Relational Data with Graph
    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__

@@ -249,24 +258,28 @@ class WeightBasis(nn.Module):
    num_outputs : int
        Number of outputs.
    """

    def __init__(self, shape, num_bases, num_outputs):
        super(WeightBasis, self).__init__()
        self.shape = shape
        self.num_bases = num_bases
        self.num_outputs = num_outputs

        if num_outputs <= num_bases:
            dgl_warning(
                "The number of weight outputs should be larger than the number"
                " of bases."
            )

        self.weight = nn.Parameter(th.Tensor(self.num_bases, *shape))
        nn.init.xavier_uniform_(
            self.weight, gain=nn.init.calculate_gain("relu")
        )
        # linear combination coefficients
        self.w_comp = nn.Parameter(th.Tensor(self.num_outputs, self.num_bases))
        nn.init.xavier_uniform_(
            self.w_comp, gain=nn.init.calculate_gain("relu")
        )

    def forward(self):
        r"""Forward computation

@@ -280,6 +293,7 @@ class WeightBasis(nn.Module):
        weight = th.matmul(self.w_comp, self.weight.view(self.num_bases, -1))
        return weight.view(self.num_outputs, *self.shape)
class JumpingKnowledge(nn.Module):
    r"""The Jumping Knowledge aggregation module from `Representation Learning on
    Graphs with Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__

@@ -345,17 +359,25 @@ class JumpingKnowledge(nn.Module):
    >>> model(feat_list).shape
    torch.Size([3, 4])
    """

    def __init__(self, mode="cat", in_feats=None, num_layers=None):
        super(JumpingKnowledge, self).__init__()
        assert mode in [
            "cat",
            "max",
            "lstm",
        ], "Expect mode to be 'cat', or 'max' or 'lstm', got {}".format(mode)
        self.mode = mode
        if mode == "lstm":
            assert in_feats is not None, "in_feats is required for lstm mode"
            assert (
                num_layers is not None
            ), "num_layers is required for lstm mode"
            hidden_size = (num_layers * in_feats) // 2
            self.lstm = nn.LSTM(
                in_feats, hidden_size, bidirectional=True, batch_first=True
            )
            self.att = nn.Linear(2 * hidden_size, 1)

    def reset_parameters(self):

@@ -365,7 +387,7 @@ class JumpingKnowledge(nn.Module):
        -----------
        Reinitialize learnable parameters. This comes into effect only for the lstm mode.
        """
        if self.mode == "lstm":
            self.lstm.reset_parameters()
            self.att.reset_parameters()

@@ -386,18 +408,21 @@ class JumpingKnowledge(nn.Module):
        Tensor
            The aggregated representations.
        """
        if self.mode == "cat":
            return th.cat(feat_list, dim=-1)
        elif self.mode == "max":
            return th.stack(feat_list, dim=-1).max(dim=-1)[0]
        else:
            # LSTM
            stacked_feat_list = th.stack(
                feat_list, dim=1
            )  # (N, num_layers, in_feats)
            alpha, _ = self.lstm(stacked_feat_list)
            alpha = self.att(alpha).squeeze(-1)  # (N, num_layers)
            alpha = th.softmax(alpha, dim=-1)
            return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1)
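# Quick sketch of the three aggregation modes implemented above, assuming the module
# is exposed as dgl.nn.JumpingKnowledge; shapes follow the docstring example.
import torch
from dgl.nn import JumpingKnowledge

feat_list = [torch.randn(5, 4) for _ in range(3)]  # 3 layers of 5 nodes x 4 features

print(JumpingKnowledge("cat")(feat_list).shape)  # (5, 12): concatenation over layers
print(JumpingKnowledge("max")(feat_list).shape)  # (5, 4): element-wise max over layers
jk = JumpingKnowledge("lstm", in_feats=4, num_layers=3)
print(jk(feat_list).shape)                       # (5, 4): attention-weighted sum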
class LabelPropagation(nn.Module):
    r"""Label Propagation from `Learning from Labeled and Unlabeled Data with Label
    Propagation <http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf>`__

@@ -447,7 +472,16 @@ class LabelPropagation(nn.Module):
    >>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()
    >>> new_labels = label_propagation(g, labels, mask)
    """

    def __init__(
        self,
        k,
        alpha,
        norm_type="sym",
        clamp=True,
        normalize=False,
        reset=False,
    ):
        super(LabelPropagation, self).__init__()
        self.k = k
        self.alpha = alpha

@@ -498,21 +532,23 @@ class LabelPropagation(nn.Module):
            init = (1 - self.alpha) * y
            in_degs = g.in_degrees().float().clamp(min=1)
            out_degs = g.out_degrees().float().clamp(min=1)
            if self.norm_type == "sym":
                norm_i = th.pow(in_degs, -0.5).to(labels.device).unsqueeze(1)
                norm_j = th.pow(out_degs, -0.5).to(labels.device).unsqueeze(1)
            elif self.norm_type == "row":
                norm_i = th.pow(in_degs, -1.0).to(labels.device).unsqueeze(1)
            else:
                raise ValueError(
                    f"Expect norm_type to be 'sym' or 'row', got {self.norm_type}"
                )

            for _ in range(self.k):
                g.ndata["h"] = y * norm_j if self.norm_type == "sym" else y
                g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                y = init + self.alpha * g.ndata["h"] * norm_i
                if self.clamp:
                    y = y.clamp_(0.0, 1.0)
                if self.normalize:
                    y = F.normalize(y, p=1)
                if self.reset:
...
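# Dense-matrix sketch of one propagation step from the loop above, for a toy symmetric
# graph: Y <- (1 - alpha) * Y0 + alpha * D^-1/2 A D^-1/2 Y (the norm_type="sym" branch).
import torch

A = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])  # toy adjacency
Y0 = torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])                # one-hot seeds
alpha = 0.9

deg = A.sum(1).clamp(min=1)
norm = torch.pow(deg, -0.5).unsqueeze(1)
Y = Y0
for _ in range(10):
    Y = (1 - alpha) * Y0 + alpha * (norm * (A @ (Y * norm)))
    Y = Y.clamp(0.0, 1.0)  # mirrors the clamp branch in the module
print(Y)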
"""Package for Tensorflow-specific NN modules.""" """Package for Tensorflow-specific NN modules."""
from .conv import * from .conv import *
from .softmax import *
from .utils import *
from .glob import * from .glob import *
from .hetero import * from .hetero import *
from .softmax import *
from .utils import *
"""TF NN conv module""" """TF NN conv module"""
from .gatconv import GATConv
from .relgraphconv import RelGraphConv
from .graphconv import GraphConv
from .ginconv import GINConv
from .sageconv import SAGEConv
from .sgconv import SGConv
from .appnpconv import APPNPConv from .appnpconv import APPNPConv
from .chebconv import ChebConv from .chebconv import ChebConv
from .densechebconv import DenseChebConv from .densechebconv import DenseChebConv
from .edgeconv import EdgeConv from .edgeconv import EdgeConv
from .gatconv import GATConv
from .ginconv import GINConv
from .graphconv import GraphConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
from .sgconv import SGConv
"""TF Module for APPNPConv""" """TF Module for APPNPConv"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
import numpy as np
from .... import function as fn from .... import function as fn
...@@ -29,10 +29,7 @@ class APPNPConv(layers.Layer): ...@@ -29,10 +29,7 @@ class APPNPConv(layers.Layer):
messages received by each node. Default: ``0``. messages received by each node. Default: ``0``.
""" """
def __init__(self, def __init__(self, k, alpha, edge_drop=0.0):
k,
alpha,
edge_drop=0.):
super(APPNPConv, self).__init__() super(APPNPConv, self).__init__()
self._k = k self._k = k
self._alpha = alpha self._alpha = alpha
...@@ -56,8 +53,11 @@ class APPNPConv(layers.Layer): ...@@ -56,8 +53,11 @@ class APPNPConv(layers.Layer):
should be the same as input shape. should be the same as input shape.
""" """
with graph.local_scope(): with graph.local_scope():
degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32), degs = tf.clip_by_value(
clip_value_min=1, clip_value_max=np.inf) tf.cast(graph.in_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf,
)
norm = tf.pow(degs, -0.5) norm = tf.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat.ndim - 1) shp = norm.shape + (1,) * (feat.ndim - 1)
norm = tf.reshape(norm, shp) norm = tf.reshape(norm, shp)
...@@ -65,12 +65,12 @@ class APPNPConv(layers.Layer): ...@@ -65,12 +65,12 @@ class APPNPConv(layers.Layer):
for _ in range(self._k): for _ in range(self._k):
# normalization by src node # normalization by src node
feat = feat * norm feat = feat * norm
graph.ndata['h'] = feat graph.ndata["h"] = feat
graph.edata['w'] = self.edge_drop( graph.edata["w"] = self.edge_drop(
tf.ones(graph.number_of_edges(), 1)) tf.ones(graph.number_of_edges(), 1)
graph.update_all(fn.u_mul_e('h', 'w', 'm'), )
fn.sum('m', 'h')) graph.update_all(fn.u_mul_e("h", "w", "m"), fn.sum("m", "h"))
feat = graph.ndata.pop('h') feat = graph.ndata.pop("h")
# normalization by dst node # normalization by dst node
feat = feat * norm feat = feat * norm
feat = (1 - self._alpha) * feat + self._alpha * feat_0 feat = (1 - self._alpha) * feat + self._alpha * feat_0
......
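# Dense-tensor sketch of one APPNP propagation step from the loop above, on a toy
# symmetric adjacency: H <- (1 - alpha) * D^-1/2 A D^-1/2 H + alpha * H0.
import tensorflow as tf

A = tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])  # toy adjacency
H0 = tf.random.normal((3, 4))                                         # initial features
alpha, k = 0.1, 3

deg = tf.clip_by_value(tf.reduce_sum(A, 1), 1.0, float("inf"))
norm = tf.expand_dims(tf.pow(deg, -0.5), 1)
H = H0
for _ in range(k):
    H = (1 - alpha) * (norm * (A @ (H * norm))) + alpha * H0
print(H.shape)  # (3, 4)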
"""Tensorflow Module for Chebyshev Spectral Graph Convolution layer""" """Tensorflow Module for Chebyshev Spectral Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
import numpy as np
from .... import broadcast_nodes
from .... import function as fn
from ....base import dgl_warning from ....base import dgl_warning
from .... import broadcast_nodes, function as fn
class ChebConv(layers.Layer): class ChebConv(layers.Layer):
...@@ -60,12 +61,9 @@ class ChebConv(layers.Layer): ...@@ -60,12 +61,9 @@ class ChebConv(layers.Layer):
[-0.2370, 3.0164]], dtype=float32)> [-0.2370, 3.0164]], dtype=float32)>
""" """
def __init__(self, def __init__(
in_feats, self, in_feats, out_feats, k, activation=tf.nn.relu, bias=True
out_feats, ):
k,
activation=tf.nn.relu,
bias=True):
super(ChebConv, self).__init__() super(ChebConv, self).__init__()
self._k = k self._k = k
self._in_feats = in_feats self._in_feats = in_feats
@@ -97,33 +95,38 @@ class ChebConv(layers.Layer):
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """

        def unnLaplacian(feat, D_invsqrt, graph):
            """Operation Feat * D^-1/2 A D^-1/2"""
            graph.ndata["h"] = feat * D_invsqrt
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            return graph.ndata.pop("h") * D_invsqrt

        with graph.local_scope():
            in_degrees = tf.clip_by_value(
                tf.cast(graph.in_degrees(), tf.float32),
                clip_value_min=1,
                clip_value_max=np.inf,
            )
            D_invsqrt = tf.expand_dims(tf.pow(in_degrees, -0.5), axis=-1)

            if lambda_max is None:
                dgl_warning(
                    "lambda_max is not provided, using default value of 2. "
                    "Please use dgl.laplacian_lambda_max to compute the eigenvalues."
                )
                lambda_max = [2] * graph.batch_size

            if isinstance(lambda_max, list):
                lambda_max = tf.constant(lambda_max, dtype=tf.float32)
            if lambda_max.ndim == 1:
                lambda_max = tf.expand_dims(
                    lambda_max, axis=-1
                )  # (B,) to (B, 1)
            # broadcast from (B, 1) to (N, 1)
            lambda_max = broadcast_nodes(graph, lambda_max)
            re_norm = 2.0 / lambda_max

            # X_0 is the raw feature, Xt refers to the concatenation of X_0, X_1, ... X_t
            Xt = X_0 = feat

@@ -131,14 +134,14 @@ class ChebConv(layers.Layer):
            # X_1(f)
            if self._k > 1:
                h = unnLaplacian(X_0, D_invsqrt, graph)
                X_1 = -re_norm * h + X_0 * (re_norm - 1)
                # Concatenate Xt and X_1
                Xt = tf.concat((Xt, X_1), 1)

            # Xi(x), i = 2...k
            for _ in range(2, self._k):
                h = unnLaplacian(X_1, D_invsqrt, graph)
                X_i = -2 * re_norm * h + X_1 * 2 * (re_norm - 1) - X_0
                # Concatenate Xt and X_i
                Xt = tf.concat((Xt, X_i), 1)
                X_1, X_0 = X_i, X_1
...
"""Tensorflow Module for DenseChebConv""" """Tensorflow Module for DenseChebConv"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
import numpy as np
class DenseChebConv(layers.Layer): class DenseChebConv(layers.Layer):
...@@ -30,11 +30,7 @@ class DenseChebConv(layers.Layer): ...@@ -30,11 +30,7 @@ class DenseChebConv(layers.Layer):
`ChebConv <https://docs.dgl.ai/api/python/nn.tensorflow.html#chebconv>`__ `ChebConv <https://docs.dgl.ai/api/python/nn.tensorflow.html#chebconv>`__
""" """
def __init__(self, def __init__(self, in_feats, out_feats, k, bias=True):
in_feats,
out_feats,
k,
bias=True):
super(DenseChebConv, self).__init__() super(DenseChebConv, self).__init__()
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
...@@ -42,13 +38,19 @@ class DenseChebConv(layers.Layer): ...@@ -42,13 +38,19 @@ class DenseChebConv(layers.Layer):
# keras initializer assume last two dims as fan_in and fan_out # keras initializer assume last two dims as fan_in and fan_out
xinit = tf.keras.initializers.glorot_normal() xinit = tf.keras.initializers.glorot_normal()
self.W = tf.Variable(initial_value=xinit( self.W = tf.Variable(
shape=(k, in_feats, out_feats), dtype='float32'), trainable=True) initial_value=xinit(
shape=(k, in_feats, out_feats), dtype="float32"
),
trainable=True,
)
if bias: if bias:
zeroinit = tf.keras.initializers.zeros() zeroinit = tf.keras.initializers.zeros()
self.bias = tf.Variable(initial_value=zeroinit( self.bias = tf.Variable(
shape=(out_feats), dtype='float32'), trainable=True) initial_value=zeroinit(shape=(out_feats), dtype="float32"),
trainable=True,
)
else: else:
self.bias = None self.bias = None
@@ -76,9 +78,11 @@ class DenseChebConv(layers.Layer):
        """
        A = adj
        num_nodes = A.shape[0]

        in_degree = 1 / tf.sqrt(
            tf.clip_by_value(
                tf.reduce_sum(A, 1), clip_value_min=1, clip_value_max=np.inf
            )
        )
        D_invsqrt = tf.linalg.diag(in_degree)
        I = tf.eye(num_nodes)
        L = I - D_invsqrt @ A @ D_invsqrt

@@ -97,7 +101,7 @@ class DenseChebConv(layers.Layer):
        Zs = tf.stack(Z, 0)  # (k, n, n)
        Zh = Zs @ tf.expand_dims(feat, axis=0) @ self.W
        Zh = tf.reduce_sum(Zh, 0)

        if self.bias is not None:
...
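# Sketch of how the stacked Chebyshev terms above combine with the weights,
# Zh = sum_k T_k(L_hat) @ X @ W_k, using random dense tensors; Zs stands in for the
# T_k(L_hat) stack built from L above (its construction is not shown in this hunk).
import tensorflow as tf

k, n, in_feats, out_feats = 3, 5, 4, 2
Zs = tf.random.normal((k, n, n))
feat = tf.random.normal((n, in_feats))
W = tf.random.normal((k, in_feats, out_feats))

Zh = Zs @ tf.expand_dims(feat, axis=0) @ W  # (k, n, out_feats), as in the line above
out = tf.reduce_sum(Zh, 0)                  # (n, out_feats)
print(out.shape)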
@@ -60,10 +60,8 @@ class EdgeConv(layers.Layer):
    A common practise to handle this is to filter out the nodes with zero-in-degree when use
    after conv.
    """

    def __init__(self, out_feats, batch_norm=False, allow_zero_in_degree=False):
        super(EdgeConv, self).__init__()
        self.batch_norm = batch_norm
        self._allow_zero_in_degree = allow_zero_in_degree

@@ -111,29 +109,31 @@ class EdgeConv(layers.Layer):
        """
        with g.local_scope():
            if not self._allow_zero_in_degree:
                if tf.math.count_nonzero(g.in_degrees() == 0) > 0:
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            h_src, h_dst = expand_as_pair(feat, g)
            g.srcdata["x"] = h_src
            g.dstdata["x"] = h_dst
            g.apply_edges(fn.v_sub_u("x", "x", "theta"))
            g.edata["theta"] = self.theta(g.edata["theta"])
            g.dstdata["phi"] = self.phi(g.dstdata["x"])
            if not self.batch_norm:
                g.update_all(fn.e_add_v("theta", "phi", "e"), fn.max("e", "x"))
            else:
                g.apply_edges(fn.e_add_v("theta", "phi", "e"))
                # for more comments on why global batch norm instead
                # of batch norm within EdgeConv go to
                # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/edgeconv.py
                g.edata["e"] = self.bn(g.edata["e"])
                g.update_all(fn.copy_e("e", "e"), fn.max("e", "x"))
            return g.dstdata["x"]
@@ -57,23 +57,26 @@ class GINConv(layers.Layer):
    0.55207 , 1.2442873 , -0.17693758, 0.67841303, 0.8633929 ]],
    dtype=float32)>
    """

    def __init__(
        self, apply_func, aggregator_type, init_eps=0, learn_eps=False
    ):
        super(GINConv, self).__init__()
        self.apply_func = apply_func
        if aggregator_type == "sum":
            self._reducer = fn.sum
        elif aggregator_type == "max":
            self._reducer = fn.max
        elif aggregator_type == "mean":
            self._reducer = fn.mean
        else:
            raise KeyError(
                "Aggregator type {} not recognized.".format(aggregator_type)
            )
        # to specify whether eps is trainable or not.
        self.eps = tf.Variable(
            initial_value=[init_eps], dtype=tf.float32, trainable=learn_eps
        )

    def call(self, graph, feat):
        r"""Compute Graph Isomorphism Network layer.

@@ -100,9 +103,9 @@ class GINConv(layers.Layer):
        """
        with graph.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, graph)
            graph.srcdata["h"] = feat_src
            graph.update_all(fn.copy_u("h", "m"), self._reducer("m", "neigh"))
            rst = (1 + self.eps) * feat_dst + graph.dstdata["neigh"]
            if self.apply_func is not None:
                rst = self.apply_func(rst)
            return rst
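# Plain-tensor sketch of the GIN update computed in call() above:
# h_v' = apply_func((1 + eps) * h_v + aggregated neighbor features).
import tensorflow as tf

eps = tf.Variable([0.0])              # mirrors self.eps with init_eps=0
feat_dst = tf.random.normal((4, 8))   # destination node features
neigh = tf.random.normal((4, 8))      # stands in for graph.dstdata["neigh"]
rst = (1 + eps) * feat_dst + neigh    # same combination as above
print(rst.shape)                      # (4, 8); apply_func (an MLP) would map this further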
@@ -108,60 +108,84 @@ class RelGraphConv(layers.Layer):
    [-0.14293689, 0.77483284],
    [ 0.091169 , -0.06761569]], dtype=float32)>
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        num_rels,
        regularizer="basis",
        num_bases=None,
        bias=True,
        activation=None,
        self_loop=True,
        low_mem=False,
        dropout=0.0,
        layer_norm=False,
    ):
        super(RelGraphConv, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_rels = num_rels
        self.regularizer = regularizer
        self.num_bases = num_bases
        if (
            self.num_bases is None
            or self.num_bases > self.num_rels
            or self.num_bases < 0
        ):
            self.num_bases = self.num_rels
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.low_mem = low_mem
        assert (
            layer_norm is False
        ), "TensorFlow currently does not support layer norm."

        xinit = tf.keras.initializers.glorot_uniform()
        zeroinit = tf.keras.initializers.zeros()

        if regularizer == "basis":
            # add basis weights
            self.weight = tf.Variable(
                initial_value=xinit(
                    shape=(self.num_bases, self.in_feat, self.out_feat),
                    dtype="float32",
                ),
                trainable=True,
            )
            if self.num_bases < self.num_rels:
                # linear combination coefficients
                self.w_comp = tf.Variable(
                    initial_value=xinit(
                        shape=(self.num_rels, self.num_bases), dtype="float32"
                    ),
                    trainable=True,
                )
            # message func
            self.message_func = self.basis_message_func
        elif regularizer == "bdd":
            if in_feat % num_bases != 0 or out_feat % num_bases != 0:
                raise ValueError(
                    "Feature size must be a multiplier of num_bases."
                )
            # add block diagonal weights
            self.submat_in = in_feat // self.num_bases
            self.submat_out = out_feat // self.num_bases

            # assuming in_feat and out_feat are both divisible by num_bases
            self.weight = tf.Variable(
                initial_value=xinit(
                    shape=(
                        self.num_rels,
                        self.num_bases * self.submat_in * self.submat_out,
                    ),
                    dtype="float32",
                ),
                trainable=True,
            )
            # message func
            self.message_func = self.bdd_message_func
        else:
@@ -169,13 +193,17 @@ class RelGraphConv(layers.Layer):
        # bias
        if self.bias:
            self.h_bias = tf.Variable(
                initial_value=zeroinit(shape=(out_feat), dtype="float32"),
                trainable=True,
            )

        # weight for self loop
        if self.self_loop:
            self.loop_weight = tf.Variable(
                initial_value=xinit(shape=(in_feat, out_feat), dtype="float32"),
                trainable=True,
            )

        self.dropout = layers.Dropout(rate=dropout)
@@ -183,64 +211,76 @@ class RelGraphConv(layers.Layer):
        """Message function for basis regularizer"""
        if self.num_bases < self.num_rels:
            # generate all weights from bases
            weight = tf.reshape(
                self.weight, (self.num_bases, self.in_feat * self.out_feat)
            )
            weight = tf.reshape(
                tf.matmul(self.w_comp, weight),
                (self.num_rels, self.in_feat, self.out_feat),
            )
        else:
            weight = self.weight

        # calculate msg @ W_r before put msg into edge
        # if src is th.int64 we expect it is an index select
        if edges.src["h"].dtype != tf.int64 and self.low_mem:
            etypes, _ = tf.unique(edges.data["type"])
            msg = tf.zeros([edges.src["h"].shape[0], self.out_feat])
            idx = tf.range(edges.src["h"].shape[0])
            for etype in etypes:
                loc = edges.data["type"] == etype
                w = weight[etype]
                src = tf.boolean_mask(edges.src["h"], loc)
                sub_msg = tf.matmul(src, w)
                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
        else:
            msg = utils.bmm_maybe_select(
                edges.src["h"], weight, edges.data["type"]
            )
        if "norm" in edges.data:
            msg = msg * edges.data["norm"]
        return {"msg": msg}

    def bdd_message_func(self, edges):
        """Message function for block-diagonal-decomposition regularizer"""
        if (edges.src["h"].dtype == tf.int64) and len(
            edges.src["h"].shape
        ) == 1:
            raise TypeError(
                "Block decomposition does not allow integer ID feature."
            )

        # calculate msg @ W_r before put msg into edge
        # if src is th.int64 we expect it is an index select
        if self.low_mem:
            etypes, _ = tf.unique(edges.data["type"])
            msg = tf.zeros([edges.src["h"].shape[0], self.out_feat])
            idx = tf.range(edges.src["h"].shape[0])
            for etype in etypes:
                loc = edges.data["type"] == etype
                w = tf.reshape(
                    self.weight[etype],
                    (self.num_bases, self.submat_in, self.submat_out),
                )
                src = tf.reshape(
                    tf.boolean_mask(edges.src["h"], loc),
                    (-1, self.num_bases, self.submat_in),
                )
                sub_msg = tf.einsum("abc,bcd->abd", src, w)
                sub_msg = tf.reshape(sub_msg, (-1, self.out_feat))
                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
        else:
            weight = tf.reshape(
                tf.gather(self.weight, edges.data["type"]),
                (-1, self.submat_in, self.submat_out),
            )
            node = tf.reshape(edges.src["h"], (-1, 1, self.submat_in))
            msg = tf.reshape(tf.matmul(node, weight), (-1, self.out_feat))
        if "norm" in edges.data:
            msg = msg * edges.data["norm"]
        return {"msg": msg}

    def call(self, g, x, etypes, norm=None):
        """Forward computation
@@ -265,20 +305,21 @@ class RelGraphConv(layers.Layer):
        tf.Tensor
            New node features.
        """
        assert g.is_homogeneous, (
            "not a homogeneous graph; convert it with to_homogeneous "
            "and pass in the edge type as argument"
        )
        with g.local_scope():
            g.ndata["h"] = x
            g.edata["type"] = tf.cast(etypes, tf.int64)
            if norm is not None:
                g.edata["norm"] = norm
            if self.self_loop:
                loop_message = utils.matmul_maybe_select(x, self.loop_weight)
            # message passing
            g.update_all(self.message_func, fn.sum(msg="msg", out="h"))
            # apply bias and activation
            node_repr = g.ndata["h"]
            if self.bias:
                node_repr = node_repr + self.h_bias
            if self.self_loop:
...
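# Sketch of the basis-decomposition weight generation in basis_message_func above,
# W_r = sum_b w_comp[r, b] * V_b, realized as a reshape + matmul; shapes are made up.
import tensorflow as tf

num_rels, num_bases, in_feat, out_feat = 5, 2, 4, 3
V = tf.random.normal((num_bases, in_feat, out_feat))   # basis weights (self.weight)
w_comp = tf.random.normal((num_rels, num_bases))       # per-relation coefficients

flat = tf.reshape(V, (num_bases, in_feat * out_feat))
W = tf.reshape(tf.matmul(w_comp, flat), (num_rels, in_feat, out_feat))
print(W.shape)  # (5, 4, 3): one weight matrix per relation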
"""tf Module for Simplifying Graph Convolution layer""" """tf Module for Simplifying Graph Convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name, W0613 # pylint: disable= no-member, arguments-differ, invalid-name, W0613
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
import numpy as np
from .... import function as fn from .... import function as fn
from ....base import DGLError from ....base import DGLError
...@@ -83,14 +83,17 @@ class SGConv(layers.Layer): ...@@ -83,14 +83,17 @@ class SGConv(layers.Layer):
[0.60570633, 0.520766 ], [0.60570633, 0.520766 ],
[0.6102368 , 0.52466124]], dtype=float32)> [0.6102368 , 0.52466124]], dtype=float32)>
""" """
def __init__(self,
in_feats, def __init__(
out_feats, self,
k=1, in_feats,
cached=False, out_feats,
bias=True, k=1,
norm=None, cached=False,
allow_zero_in_degree=False): bias=True,
norm=None,
allow_zero_in_degree=False,
):
super(SGConv, self).__init__() super(SGConv, self).__init__()
self.fc = layers.Dense(out_feats, use_bias=bias) self.fc = layers.Dense(out_feats, use_bias=bias)
self._cached = cached self._cached = cached
@@ -140,32 +143,36 @@ class SGConv(layers.Layer):
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            if self._cached_h is not None:
                feat = self._cached_h
            else:
                # compute normalization
                degs = tf.clip_by_value(
                    tf.cast(graph.in_degrees(), tf.float32),
                    clip_value_min=1,
                    clip_value_max=np.inf,
                )
                norm = tf.pow(degs, -0.5)
                norm = tf.expand_dims(norm, 1)
                # compute (D^-1 A^k D)^k X
                for _ in range(self._k):
                    feat = feat * norm
                    graph.ndata["h"] = feat
                    graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                    feat = graph.ndata.pop("h")
                    feat = feat * norm

                if self.norm is not None:
...
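# Dense sketch of the k-step SGC propagation in the loop above, on a toy symmetric
# adjacency: X <- (D^-1/2 A D^-1/2)^k X, followed by a single linear layer standing
# in for the module's self.fc.
import tensorflow as tf

A = tf.constant([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # toy adjacency
X = tf.random.normal((3, 4))
k = 2

deg = tf.clip_by_value(tf.reduce_sum(A, 1), 1.0, float("inf"))
norm = tf.expand_dims(tf.pow(deg, -0.5), 1)
for _ in range(k):
    X = norm * (A @ (X * norm))
out = tf.keras.layers.Dense(2)(X)
print(out.shape)  # (3, 2)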
@@ -3,13 +3,22 @@
import tensorflow as tf
from tensorflow.keras import layers

from ...readout import (
    max_nodes,
    mean_nodes,
    softmax_nodes,
    sum_nodes,
    topk_nodes,
)

__all__ = [
    "SumPooling",
    "AvgPooling",
    "MaxPooling",
    "SortPooling",
    "WeightAndSum",
    "GlobalAttentionPooling",
]


class SumPooling(layers.Layer):
...@@ -41,8 +50,8 @@ class SumPooling(layers.Layer): ...@@ -41,8 +50,8 @@ class SumPooling(layers.Layer):
:math:`B` refers to the batch size. :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata["h"] = feat
readout = sum_nodes(graph, 'h') readout = sum_nodes(graph, "h")
return readout return readout
...@@ -74,8 +83,8 @@ class AvgPooling(layers.Layer): ...@@ -74,8 +83,8 @@ class AvgPooling(layers.Layer):
:math:`B` refers to the batch size. :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata["h"] = feat
readout = mean_nodes(graph, 'h') readout = mean_nodes(graph, "h")
return readout return readout
...@@ -107,8 +116,8 @@ class MaxPooling(layers.Layer): ...@@ -107,8 +116,8 @@ class MaxPooling(layers.Layer):
:math:`B` refers to the batch size. :math:`B` refers to the batch size.
""" """
with graph.local_scope(): with graph.local_scope():
graph.ndata['h'] = feat graph.ndata["h"] = feat
readout = max_nodes(graph, 'h') readout = max_nodes(graph, "h")
return readout return readout
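The three readouts above differ only in the reducer applied by sum_nodes, mean_nodes and max_nodes: each collapses the node features of every graph in the batch into one vector per graph. A rough plain-TensorFlow sketch of the same reductions, using a made-up node-to-graph assignment instead of a DGL batched graph:

import tensorflow as tf

feat = tf.random.normal((5, 4))           # 5 nodes, 4 features each
graph_ids = tf.constant([0, 0, 0, 1, 1])  # hypothetical node-to-graph mapping
sum_readout = tf.math.segment_sum(feat, graph_ids)    # (2, 4), like sum_nodes
mean_readout = tf.math.segment_mean(feat, graph_ids)  # like mean_nodes
max_readout = tf.math.segment_max(feat, graph_ids)    # like max_nodes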
...@@ -146,10 +155,12 @@ class SortPooling(layers.Layer): ...@@ -146,10 +155,12 @@ class SortPooling(layers.Layer):
with graph.local_scope(): with graph.local_scope():
# Sort the feature of each node in ascending order. # Sort the feature of each node in ascending order.
feat = tf.sort(feat, -1) feat = tf.sort(feat, -1)
graph.ndata['h'] = feat graph.ndata["h"] = feat
# Sort nodes according to their last features. # Sort nodes according to their last features.
ret = tf.reshape(topk_nodes(graph, 'h', self.k, sortby=-1)[0], ( ret = tf.reshape(
-1, self.k * feat.shape[-1])) topk_nodes(graph, "h", self.k, sortby=-1)[0],
(-1, self.k * feat.shape[-1]),
)
return ret return ret
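SortPooling first sorts each node's feature vector in ascending order, then keeps the k nodes ranked highest by their last feature and flattens them into a (B, k * D) tensor. A single-graph approximation in plain TensorFlow (topk_nodes additionally handles batching and padding, which this sketch ignores):

import tensorflow as tf

k = 2
feat = tf.random.normal((5, 3))              # one graph: 5 nodes, 3 features
feat = tf.sort(feat, -1)                     # sort each node's features ascending
order = tf.argsort(feat[:, -1], direction="DESCENDING")  # rank nodes by last feature
ret = tf.reshape(tf.gather(feat, order[:k]), (1, k * feat.shape[-1]))  # (1, k * 3)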
...@@ -194,16 +205,18 @@ class GlobalAttentionPooling(layers.Layer): ...@@ -194,16 +205,18 @@ class GlobalAttentionPooling(layers.Layer):
""" """
with graph.local_scope(): with graph.local_scope():
gate = self.gate_nn(feat) gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis." assert (
gate.shape[-1] == 1
), "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat feat = self.feat_nn(feat) if self.feat_nn else feat
graph.ndata['gate'] = gate graph.ndata["gate"] = gate
gate = softmax_nodes(graph, 'gate') gate = softmax_nodes(graph, "gate")
graph.ndata.pop('gate') graph.ndata.pop("gate")
graph.ndata['r'] = feat * gate graph.ndata["r"] = feat * gate
readout = sum_nodes(graph, 'r') readout = sum_nodes(graph, "r")
graph.ndata.pop('r') graph.ndata.pop("r")
return readout return readout
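GlobalAttentionPooling turns gate_nn's per-node scalar into a within-graph softmax and uses it to weight the (optionally transformed) node features before summing. A single-graph sketch in plain TensorFlow, with a stand-in Dense layer for gate_nn:

import tensorflow as tf

feat = tf.random.normal((5, 4))               # one graph, 5 nodes
gate_nn = tf.keras.layers.Dense(1)            # stand-in gate network
gate = tf.nn.softmax(gate_nn(feat), axis=0)   # per-node attention, like softmax_nodes
readout = tf.reduce_sum(feat * gate, axis=0, keepdims=True)  # (1, 4) graph readout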
...@@ -221,8 +234,7 @@ class WeightAndSum(layers.Layer): ...@@ -221,8 +234,7 @@ class WeightAndSum(layers.Layer):
super(WeightAndSum, self).__init__() super(WeightAndSum, self).__init__()
self.in_feats = in_feats self.in_feats = in_feats
self.atom_weighting = tf.keras.Sequential( self.atom_weighting = tf.keras.Sequential(
layers.Dense(1), layers.Dense(1), layers.Activation(tf.nn.sigmoid)
layers.Activation(tf.nn.sigmoid)
) )
def call(self, g, feats): def call(self, g, feats):
...@@ -242,8 +254,8 @@ class WeightAndSum(layers.Layer): ...@@ -242,8 +254,8 @@ class WeightAndSum(layers.Layer):
Representations for B molecules Representations for B molecules
""" """
with g.local_scope(): with g.local_scope():
g.ndata['h'] = feats g.ndata["h"] = feats
g.ndata['w'] = self.atom_weighting(g.ndata['h']) g.ndata["w"] = self.atom_weighting(g.ndata["h"])
h_g_sum = sum_nodes(g, 'h', 'w') h_g_sum = sum_nodes(g, "h", "w")
return h_g_sum return h_g_sum
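WeightAndSum follows the same pattern, except the per-node weights come from a Dense(1) layer with a sigmoid rather than a per-graph softmax; under the same single-graph assumption:

import tensorflow as tf

feat = tf.random.normal((5, 4))
weights = tf.keras.layers.Dense(1, activation="sigmoid")(feat)  # stand-in atom weights
readout = tf.reduce_sum(feat * weights, axis=0, keepdims=True)  # weighted-sum readout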
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
__all__ = ['HeteroGraphConv'] __all__ = ["HeteroGraphConv"]
class HeteroGraphConv(layers.Layer): class HeteroGraphConv(layers.Layer):
r"""A generic module for computing convolution on heterogeneous graphs. r"""A generic module for computing convolution on heterogeneous graphs.
...@@ -125,13 +126,16 @@ class HeteroGraphConv(layers.Layer): ...@@ -125,13 +126,16 @@ class HeteroGraphConv(layers.Layer):
mods : dict[str, nn.Module] mods : dict[str, nn.Module]
Modules associated with every edge types. Modules associated with every edge types.
""" """
def __init__(self, mods, aggregate='sum'):
def __init__(self, mods, aggregate="sum"):
super(HeteroGraphConv, self).__init__() super(HeteroGraphConv, self).__init__()
self.mods = mods self.mods = mods
# Do not break if graph has 0-in-degree nodes. # Do not break if graph has 0-in-degree nodes.
# Because there is no general rule to add self-loop for heterograph. # Because there is no general rule to add self-loop for heterograph.
for _, v in self.mods.items(): for _, v in self.mods.items():
set_allow_zero_in_degree_fn = getattr(v, 'set_allow_zero_in_degree', None) set_allow_zero_in_degree_fn = getattr(
v, "set_allow_zero_in_degree", None
)
if callable(set_allow_zero_in_degree_fn): if callable(set_allow_zero_in_degree_fn):
set_allow_zero_in_degree_fn(True) set_allow_zero_in_degree_fn(True)
if isinstance(aggregate, str): if isinstance(aggregate, str):
...@@ -164,7 +168,7 @@ class HeteroGraphConv(layers.Layer): ...@@ -164,7 +168,7 @@ class HeteroGraphConv(layers.Layer):
mod_args = {} mod_args = {}
if mod_kwargs is None: if mod_kwargs is None:
mod_kwargs = {} mod_kwargs = {}
outputs = {nty : [] for nty in g.dsttypes} outputs = {nty: [] for nty in g.dsttypes}
if isinstance(inputs, tuple): if isinstance(inputs, tuple):
src_inputs, dst_inputs = inputs src_inputs, dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
...@@ -175,7 +179,8 @@ class HeteroGraphConv(layers.Layer): ...@@ -175,7 +179,8 @@ class HeteroGraphConv(layers.Layer):
rel_graph, rel_graph,
(src_inputs[stype], dst_inputs[dtype]), (src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()), *mod_args.get(etype, ()),
**mod_kwargs.get(etype, {})) **mod_kwargs.get(etype, {})
)
outputs[dtype].append(dstdata) outputs[dtype].append(dstdata)
else: else:
for stype, etype, dtype in g.canonical_etypes: for stype, etype, dtype in g.canonical_etypes:
...@@ -186,7 +191,8 @@ class HeteroGraphConv(layers.Layer): ...@@ -186,7 +191,8 @@ class HeteroGraphConv(layers.Layer):
rel_graph, rel_graph,
(inputs[stype], inputs[dtype]), (inputs[stype], inputs[dtype]),
*mod_args.get(etype, ()), *mod_args.get(etype, ()),
**mod_kwargs.get(etype, {})) **mod_kwargs.get(etype, {})
)
outputs[dtype].append(dstdata) outputs[dtype].append(dstdata)
rsts = {} rsts = {}
for nty, alist in outputs.items(): for nty, alist in outputs.items():
...@@ -194,6 +200,7 @@ class HeteroGraphConv(layers.Layer): ...@@ -194,6 +200,7 @@ class HeteroGraphConv(layers.Layer):
rsts[nty] = self.agg_fn(alist, nty) rsts[nty] = self.agg_fn(alist, nty)
return rsts return rsts
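HeteroGraphConv runs one submodule per canonical edge type and then merges the per-relation results for each destination node type with agg_fn. A hypothetical usage sketch; the graph, shapes and feature names are invented, and it assumes DGL is running on the TensorFlow backend so that dgl.nn resolves to these modules:

import dgl
import dgl.nn as dglnn
import tensorflow as tf

g = dgl.heterograph({
    ("user", "follows", "user"): ([0, 1], [1, 2]),
    ("user", "plays", "game"): ([0, 1], [0, 1]),
})
conv = dglnn.HeteroGraphConv(
    {"follows": dglnn.GraphConv(5, 4), "plays": dglnn.GraphConv(5, 4)},
    aggregate="sum",
)
h = conv(g, {"user": tf.random.normal((3, 5)), "game": tf.random.normal((2, 5))})
# h["user"] has shape (3, 4) and h["game"] has shape (2, 4)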
def get_aggregate_fn(agg): def get_aggregate_fn(agg):
"""Internal function to get the aggregation function for node data """Internal function to get the aggregation function for node data
generated from different relations. generated from different relations.
...@@ -210,29 +217,35 @@ def get_aggregate_fn(agg): ...@@ -210,29 +217,35 @@ def get_aggregate_fn(agg):
Aggregator function that takes a list of tensors to aggregate Aggregator function that takes a list of tensors to aggregate
and returns one aggregated tensor. and returns one aggregated tensor.
""" """
if agg == 'sum': if agg == "sum":
fn = tf.reduce_sum fn = tf.reduce_sum
elif agg == 'max': elif agg == "max":
fn = tf.reduce_max fn = tf.reduce_max
elif agg == 'min': elif agg == "min":
fn = tf.reduce_min fn = tf.reduce_min
elif agg == 'mean': elif agg == "mean":
fn = tf.reduce_mean fn = tf.reduce_mean
elif agg == 'stack': elif agg == "stack":
fn = None # will not be called fn = None # will not be called
else: else:
raise DGLError('Invalid cross type aggregator. Must be one of ' raise DGLError(
'"sum", "max", "min", "mean" or "stack". But got "%s"' % agg) "Invalid cross type aggregator. Must be one of "
if agg == 'stack': '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg
)
if agg == "stack":
def stack_agg(inputs, dsttype): # pylint: disable=unused-argument def stack_agg(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0: if len(inputs) == 0:
return None return None
return tf.stack(inputs, axis=1) return tf.stack(inputs, axis=1)
return stack_agg return stack_agg
else: else:
def aggfn(inputs, dsttype): # pylint: disable=unused-argument def aggfn(inputs, dsttype): # pylint: disable=unused-argument
if len(inputs) == 0: if len(inputs) == 0:
return None return None
stacked = tf.stack(inputs, axis=0) stacked = tf.stack(inputs, axis=0)
return fn(stacked, axis=0) return fn(stacked, axis=0)
return aggfn return aggfn
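Apart from "stack", every aggregator stacks the per-relation tensors on a new leading axis and reduces over it; "stack" instead keeps the relation axis at position 1. A small sketch of the two behaviours with made-up per-relation outputs for one node type:

import tensorflow as tf

per_rel = [tf.ones((3, 4)), 2 * tf.ones((3, 4))]           # two relations, 3 dst nodes each
summed = tf.reduce_sum(tf.stack(per_rel, axis=0), axis=0)  # "sum": shape (3, 4)
stacked = tf.stack(per_rel, axis=1)                        # "stack": shape (3, 2, 4)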
"""Utilities for tf NN package""" """Utilities for tf NN package"""
# pylint: disable=no-member, invalid-name # pylint: disable=no-member, invalid-name
from tensorflow.keras import layers # pylint: disable=W0235
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers # pylint: disable=W0235
def matmul_maybe_select(A, B): def matmul_maybe_select(A, B):
...@@ -91,8 +91,7 @@ def bmm_maybe_select(A, B, index): ...@@ -91,8 +91,7 @@ def bmm_maybe_select(A, B, index):
class Identity(layers.Layer): class Identity(layers.Layer):
"""A placeholder identity operator that is argument-insensitive. """A placeholder identity operator that is argument-insensitive."""
"""
def call(self, x): def call(self, x):
"""Return input""" """Return input"""
......
"""dgl operator module.""" """dgl operator module."""
from .spmm import *
from .sddmm import *
from .edge_softmax import * from .edge_softmax import *
from .segment import *
from .gather_mm import * from .gather_mm import *
from .sddmm import *
from .segment import *
from .spmm import *
"""dgl edge_softmax operator module.""" """dgl edge_softmax operator module."""
from ..backend import astype
from ..backend import edge_softmax as edge_softmax_internal from ..backend import edge_softmax as edge_softmax_internal
from ..backend import edge_softmax_hetero as edge_softmax_hetero_internal from ..backend import edge_softmax_hetero as edge_softmax_hetero_internal
from ..backend import astype
from ..base import ALL, is_all from ..base import ALL, is_all
__all__ = ['edge_softmax'] __all__ = ["edge_softmax"]
def edge_softmax(graph, logits, eids=ALL, norm_by='dst'): def edge_softmax(graph, logits, eids=ALL, norm_by="dst"):
r"""Compute softmax over weights of incoming edges for every node. r"""Compute softmax over weights of incoming edges for every node.
For a node :math:`i`, edge softmax is an operation that computes For a node :math:`i`, edge softmax is an operation that computes
...@@ -131,8 +131,9 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'): ...@@ -131,8 +131,9 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
if not is_all(eids): if not is_all(eids):
eids = astype(eids, graph.idtype) eids = astype(eids, graph.idtype)
if graph._graph.number_of_etypes() == 1: if graph._graph.number_of_etypes() == 1:
return edge_softmax_internal(graph._graph, logits, return edge_softmax_internal(
eids=eids, norm_by=norm_by) graph._graph, logits, eids=eids, norm_by=norm_by
)
else: else:
logits_list = [None] * graph._graph.number_of_etypes() logits_list = [None] * graph._graph.number_of_etypes()
logits = {graph.to_canonical_etype(k): v for k, v in logits.items()} logits = {graph.to_canonical_etype(k): v for k, v in logits.items()}
...@@ -140,8 +141,9 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'): ...@@ -140,8 +141,9 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
etid = graph.get_etype_id(rel) etid = graph.get_etype_id(rel)
logits_list[etid] = logits[rel] logits_list[etid] = logits[rel]
logits_tuple = tuple(logits_list) logits_tuple = tuple(logits_list)
score_tuple = edge_softmax_hetero_internal(graph._graph, score_tuple = edge_softmax_hetero_internal(
eids, norm_by, *logits_tuple) graph._graph, eids, norm_by, *logits_tuple
)
score = {} score = {}
for rel in graph.canonical_etypes: for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel) etid = graph.get_etype_id(rel)
......
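edge_softmax normalizes each edge's logit over all edges sharing the same destination node (or source node, with norm_by="src"). The per-destination softmax can be emulated with segment reductions; a plain-TensorFlow sketch with invented edge logits and destination ids:

import tensorflow as tf

logits = tf.constant([1.0, 2.0, 3.0, 4.0])  # one logit per edge
dst = tf.constant([0, 0, 1, 1])             # destination node of each edge (assumed)
num_nodes = 2
m = tf.gather(tf.math.unsorted_segment_max(logits, dst, num_nodes), dst)
e = tf.exp(logits - m)                      # subtract per-node max for stability
z = tf.gather(tf.math.unsorted_segment_sum(e, dst, num_nodes), dst)
scores = e / z                              # softmax over each node's incoming edges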
"""dgl gather_mm operator module.""" """dgl gather_mm operator module."""
from .. import backend as F from .. import backend as F
__all__ = ['gather_mm'] __all__ = ["gather_mm"]
def gather_mm(a, b, *, idx_b): def gather_mm(a, b, *, idx_b):
r"""Gather data according to the given indices and perform matrix multiplication. r"""Gather data according to the given indices and perform matrix multiplication.
...@@ -31,12 +32,19 @@ def gather_mm(a, b, *, idx_b): ...@@ -31,12 +32,19 @@ def gather_mm(a, b, *, idx_b):
if N > 1000000 or D1 > 8 or D2 > 8: if N > 1000000 or D1 > 8 or D2 > 8:
# Use segment_mm for large workload # Use segment_mm for large workload
import torch import torch
sorted_idx_b, perm = torch.sort(idx_b) sorted_idx_b, perm = torch.sort(idx_b)
_, rev_perm = torch.sort(perm) _, rev_perm = torch.sort(perm)
sorted_a = torch.index_select(a, 0, perm) sorted_a = torch.index_select(a, 0, perm)
pos_l = torch.searchsorted(sorted_idx_b, torch.arange(R, device=a.device)) pos_l = torch.searchsorted(
pos_r = torch.cat([pos_l[1:], torch.tensor([len(idx_b)], device=a.device)]) sorted_idx_b, torch.arange(R, device=a.device)
)
pos_r = torch.cat(
[pos_l[1:], torch.tensor([len(idx_b)], device=a.device)]
)
seglen = (pos_r - pos_l).cpu() # XXX(minjie): cause device synchronize seglen = (pos_r - pos_l).cpu() # XXX(minjie): cause device synchronize
return torch.index_select(F.segment_mm(sorted_a, b, seglen), 0, rev_perm) return torch.index_select(
F.segment_mm(sorted_a, b, seglen), 0, rev_perm
)
else: else:
return F.gather_mm(a, b, None, idx_b) return F.gather_mm(a, b, None, idx_b)
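For reference, gather_mm computes out[i] = a[i] @ b[idx_b[i]]; the segment_mm path above just sorts the rows by relation id so the same products can be computed relation by relation. A small PyTorch sketch of the reference semantics with invented shapes:

import torch

N, D1, D2, R = 6, 4, 5, 3
a = torch.randn(N, D1)              # one input row per edge/node
b = torch.randn(R, D1, D2)          # one weight matrix per relation
idx_b = torch.randint(0, R, (N,))   # relation id of each row
out = torch.bmm(a.unsqueeze(1), b[idx_b]).squeeze(1)  # (N, D2): a[i] @ b[idx_b[i]]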