Unverified Commit 63ac788f authored by Hongzhi (Steve), Chen, committed by GitHub

auto-reformat-nn (#5319)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 0b3a447b
@@ -6,7 +6,7 @@ from torch.nn import functional as F
from .... import function as fn
from ....base import DGLError
from ....utils import check_eq_shape, expand_as_pair


class SAGEConv(nn.Module):
@@ -94,20 +94,25 @@ class SAGEConv(nn.Module):
            [ 0.5873,  1.6597],
            [-0.2502,  2.8068]], grad_fn=<AddBackward0>)
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        aggregator_type,
        feat_drop=0.0,
        bias=True,
        norm=None,
        activation=None,
    ):
        super(SAGEConv, self).__init__()
        valid_aggre_types = {"mean", "gcn", "pool", "lstm"}
        if aggregator_type not in valid_aggre_types:
            raise DGLError(
                "Invalid aggregator_type. Must be one of {}. "
                "But got {!r} instead.".format(
                    valid_aggre_types, aggregator_type
                )
            )
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
@@ -118,14 +123,16 @@ class SAGEConv(nn.Module):
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == "pool":
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == "lstm":
            self.lstm = nn.LSTM(
                self._in_src_feats, self._in_src_feats, batch_first=True
            )
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
        if aggregator_type != "gcn":
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        elif bias:
            self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
@@ -146,12 +153,12 @@ class SAGEConv(nn.Module):
        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
        The LSTM module is using xavier initialization method for its weights.
        """
        gain = nn.init.calculate_gain("relu")
        if self._aggre_type == "pool":
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == "lstm":
            self.lstm.reset_parameters()
        if self._aggre_type != "gcn":
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
@@ -160,12 +167,14 @@ class SAGEConv(nn.Module):
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox["m"]  # (B, L, D)
        batch_size = m.shape[0]
        h = (
            m.new_zeros((1, batch_size, self._in_src_feats)),
            m.new_zeros((1, batch_size, self._in_src_feats)),
        )
        _, (rst, _) = self.lstm(m, h)
        return {"neigh": rst.squeeze(0)}

    def forward(self, graph, feat, edge_weight=None):
        r"""
@@ -202,59 +211,74 @@ class SAGEConv(nn.Module):
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]
            msg_fn = fn.copy_u("h", "m")
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata["_edge_weight"] = edge_weight
                msg_fn = fn.u_mul_e("h", "_edge_weight", "m")

            h_self = feat_dst

            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata["neigh"] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats
                ).to(feat_dst)

            # Determine whether to apply linear transformation before message passing A(XW)
            lin_before_mp = self._in_src_feats > self._out_feats

            # Message Passing
            if self._aggre_type == "mean":
                graph.srcdata["h"] = (
                    self.fc_neigh(feat_src) if lin_before_mp else feat_src
                )
                graph.update_all(msg_fn, fn.mean("m", "neigh"))
                h_neigh = graph.dstdata["neigh"]
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == "gcn":
                check_eq_shape(feat)
                graph.srcdata["h"] = (
                    self.fc_neigh(feat_src) if lin_before_mp else feat_src
                )
                if isinstance(feat, tuple):  # heterogeneous
                    graph.dstdata["h"] = (
                        self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
                    )
                else:
                    if graph.is_block:
                        graph.dstdata["h"] = graph.srcdata["h"][
                            : graph.num_dst_nodes()
                        ]
                    else:
                        graph.dstdata["h"] = graph.srcdata["h"]
                graph.update_all(msg_fn, fn.sum("m", "neigh"))
                # divide in_degrees
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata["neigh"] + graph.dstdata["h"]) / (
                    degs.unsqueeze(-1) + 1
                )
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == "pool":
                graph.srcdata["h"] = F.relu(self.fc_pool(feat_src))
                graph.update_all(msg_fn, fn.max("m", "neigh"))
                h_neigh = self.fc_neigh(graph.dstdata["neigh"])
            elif self._aggre_type == "lstm":
                graph.srcdata["h"] = feat_src
                graph.update_all(msg_fn, self._lstm_reducer)
                h_neigh = self.fc_neigh(graph.dstdata["neigh"])
            else:
                raise KeyError(
                    "Aggregator type {} not recognized.".format(
                        self._aggre_type
                    )
                )

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == "gcn":
                rst = h_neigh
                # add bias manually for GCN
                if self.bias is not None:
@@ -262,7 +286,6 @@ class SAGEConv(nn.Module):
            else:
                rst = self.fc_self(h_self) + h_neigh
            # activation
            if self.activation is not None:
                rst = self.activation(rst)
...
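A minimal usage sketch for the SAGEConv module reformatted above; the toy graph, feature sizes, and aggregator choice are illustrative and not part of this commit:

import dgl
import torch
from dgl.nn import SAGEConv

# Toy graph with 4 nodes and 4 directed edges (illustrative only).
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feat = torch.randn(4, 10)

# Arguments follow the __init__ signature shown above.
conv = SAGEConv(in_feats=10, out_feats=2, aggregator_type="mean")
out = conv(g, feat)  # forward(graph, feat, edge_weight=None)
print(out.shape)     # torch.Size([4, 2])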
@@ -98,7 +98,6 @@ class TWIRLSConv(nn.Module):
        attn_dropout=0.0,
        inp_dropout=0.0,
    ):
        super().__init__()
        self.input_d = input_d
        self.output_d = output_d
@@ -542,7 +541,6 @@ class TWIRLSUnfoldingAndAttention(nn.Module):
        attn_dropout=0,
        precond=True,
    ):
        super().__init__()
        self.d = d
@@ -596,7 +594,6 @@ class TWIRLSUnfoldingAndAttention(nn.Module):
        g = self.init_attn(g, Y, self.etas)
        for k, layer in enumerate(self.prop_layers):
            # do unfolding
            Y = layer(g, Y, X, self.alp, self.lam)
...
"""Torch Module for GNNExplainer""" """Torch Module for GNNExplainer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
from math import sqrt from math import sqrt
import torch import torch
from torch import nn from torch import nn
from tqdm import tqdm from tqdm import tqdm
from ....base import NID, EID from ....base import EID, NID
from ....subgraph import khop_in_subgraph from ....subgraph import khop_in_subgraph
__all__ = ['GNNExplainer', 'HeteroGNNExplainer'] __all__ = ["GNNExplainer", "HeteroGNNExplainer"]
class GNNExplainer(nn.Module): class GNNExplainer(nn.Module):
r"""GNNExplainer model from `GNNExplainer: Generating Explanations for r"""GNNExplainer model from `GNNExplainer: Generating Explanations for
...@@ -63,17 +65,19 @@ class GNNExplainer(nn.Module): ...@@ -63,17 +65,19 @@ class GNNExplainer(nn.Module):
If True, it will log the computation process, default to True. If True, it will log the computation process, default to True.
""" """
def __init__(self, def __init__(
model, self,
num_hops, model,
lr=0.01, num_hops,
num_epochs=100, lr=0.01,
*, num_epochs=100,
alpha1=0.005, *,
alpha2=1.0, alpha1=0.005,
beta1=1.0, alpha2=1.0,
beta2=0.1, beta1=1.0,
log=True): beta2=0.1,
log=True,
):
super(GNNExplainer, self).__init__() super(GNNExplainer, self).__init__()
self.model = model self.model = model
self.num_hops = num_hops self.num_hops = num_hops
...@@ -111,7 +115,7 @@ class GNNExplainer(nn.Module): ...@@ -111,7 +115,7 @@ class GNNExplainer(nn.Module):
std = 0.1 std = 0.1
feat_mask = nn.Parameter(torch.randn(1, feat_size, device=device) * std) feat_mask = nn.Parameter(torch.randn(1, feat_size, device=device) * std)
std = nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * num_nodes)) std = nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * num_nodes))
edge_mask = nn.Parameter(torch.randn(num_edges, device=device) * std) edge_mask = nn.Parameter(torch.randn(num_edges, device=device) * std)
return feat_mask, edge_mask return feat_mask, edge_mask
...@@ -142,16 +146,18 @@ class GNNExplainer(nn.Module): ...@@ -142,16 +146,18 @@ class GNNExplainer(nn.Module):
# Edge mask sparsity regularization # Edge mask sparsity regularization
loss = loss + self.alpha1 * torch.sum(edge_mask) loss = loss + self.alpha1 * torch.sum(edge_mask)
# Edge mask entropy regularization # Edge mask entropy regularization
ent = - edge_mask * torch.log(edge_mask + eps) - \ ent = -edge_mask * torch.log(edge_mask + eps) - (
(1 - edge_mask) * torch.log(1 - edge_mask + eps) 1 - edge_mask
) * torch.log(1 - edge_mask + eps)
loss = loss + self.alpha2 * ent.mean() loss = loss + self.alpha2 * ent.mean()
feat_mask = feat_mask.sigmoid() feat_mask = feat_mask.sigmoid()
# Feature mask sparsity regularization # Feature mask sparsity regularization
loss = loss + self.beta1 * torch.mean(feat_mask) loss = loss + self.beta1 * torch.mean(feat_mask)
# Feature mask entropy regularization # Feature mask entropy regularization
ent = - feat_mask * torch.log(feat_mask + eps) - \ ent = -feat_mask * torch.log(feat_mask + eps) - (
(1 - feat_mask) * torch.log(1 - feat_mask + eps) 1 - feat_mask
) * torch.log(1 - feat_mask + eps)
loss = loss + self.beta2 * ent.mean() loss = loss + self.beta2 * ent.mean()
return loss return loss
...@@ -285,13 +291,14 @@ class GNNExplainer(nn.Module): ...@@ -285,13 +291,14 @@ class GNNExplainer(nn.Module):
if self.log: if self.log:
pbar = tqdm(total=self.num_epochs) pbar = tqdm(total=self.num_epochs)
pbar.set_description(f'Explain node {node_id}') pbar.set_description(f"Explain node {node_id}")
for _ in range(self.num_epochs): for _ in range(self.num_epochs):
optimizer.zero_grad() optimizer.zero_grad()
h = feat * feat_mask.sigmoid() h = feat * feat_mask.sigmoid()
logits = self.model(graph=sg, feat=h, logits = self.model(
eweight=edge_mask.sigmoid(), **kwargs) graph=sg, feat=h, eweight=edge_mask.sigmoid(), **kwargs
)
log_probs = logits.log_softmax(dim=-1) log_probs = logits.log_softmax(dim=-1)
loss = -log_probs[inverse_indices, pred_label[inverse_indices]] loss = -log_probs[inverse_indices, pred_label[inverse_indices]]
loss = self._loss_regularize(loss, feat_mask, edge_mask) loss = self._loss_regularize(loss, feat_mask, edge_mask)
...@@ -406,13 +413,14 @@ class GNNExplainer(nn.Module): ...@@ -406,13 +413,14 @@ class GNNExplainer(nn.Module):
if self.log: if self.log:
pbar = tqdm(total=self.num_epochs) pbar = tqdm(total=self.num_epochs)
pbar.set_description('Explain graph') pbar.set_description("Explain graph")
for _ in range(self.num_epochs): for _ in range(self.num_epochs):
optimizer.zero_grad() optimizer.zero_grad()
h = feat * feat_mask.sigmoid() h = feat * feat_mask.sigmoid()
logits = self.model(graph=graph, feat=h, logits = self.model(
eweight=edge_mask.sigmoid(), **kwargs) graph=graph, feat=h, eweight=edge_mask.sigmoid(), **kwargs
)
log_probs = logits.log_softmax(dim=-1) log_probs = logits.log_softmax(dim=-1)
loss = -log_probs[0, pred_label[0]] loss = -log_probs[0, pred_label[0]]
loss = self._loss_regularize(loss, feat_mask, edge_mask) loss = self._loss_regularize(loss, feat_mask, edge_mask)
...@@ -430,6 +438,7 @@ class GNNExplainer(nn.Module): ...@@ -430,6 +438,7 @@ class GNNExplainer(nn.Module):
return feat_mask, edge_mask return feat_mask, edge_mask
class HeteroGNNExplainer(nn.Module): class HeteroGNNExplainer(nn.Module):
r"""GNNExplainer model from `GNNExplainer: Generating Explanations for r"""GNNExplainer model from `GNNExplainer: Generating Explanations for
Graph Neural Networks <https://arxiv.org/abs/1903.03894>`__, adapted for heterogeneous graphs Graph Neural Networks <https://arxiv.org/abs/1903.03894>`__, adapted for heterogeneous graphs
...@@ -482,17 +491,19 @@ class HeteroGNNExplainer(nn.Module): ...@@ -482,17 +491,19 @@ class HeteroGNNExplainer(nn.Module):
If True, it will log the computation process, default to True. If True, it will log the computation process, default to True.
""" """
def __init__(self, def __init__(
model, self,
num_hops, model,
lr=0.01, num_hops,
num_epochs=100, lr=0.01,
*, num_epochs=100,
alpha1=0.005, *,
alpha2=1.0, alpha1=0.005,
beta1=1.0, alpha2=1.0,
beta2=0.1, beta1=1.0,
log=True): beta2=0.1,
log=True,
):
super(HeteroGNNExplainer, self).__init__() super(HeteroGNNExplainer, self).__init__()
self.model = model self.model = model
self.num_hops = num_hops self.num_hops = num_hops
...@@ -531,7 +542,9 @@ class HeteroGNNExplainer(nn.Module): ...@@ -531,7 +542,9 @@ class HeteroGNNExplainer(nn.Module):
std = 0.1 std = 0.1
for node_type, feature in feat.items(): for node_type, feature in feat.items():
_, feat_size = feature.size() _, feat_size = feature.size()
feat_masks[node_type] = nn.Parameter(torch.randn(1, feat_size, device=device) * std) feat_masks[node_type] = nn.Parameter(
torch.randn(1, feat_size, device=device) * std
)
edge_masks = {} edge_masks = {}
for canonical_etype in graph.canonical_etypes: for canonical_etype in graph.canonical_etypes:
...@@ -539,11 +552,12 @@ class HeteroGNNExplainer(nn.Module): ...@@ -539,11 +552,12 @@ class HeteroGNNExplainer(nn.Module):
dst_num_nodes = graph.num_nodes(canonical_etype[-1]) dst_num_nodes = graph.num_nodes(canonical_etype[-1])
num_nodes_sum = src_num_nodes + dst_num_nodes num_nodes_sum = src_num_nodes + dst_num_nodes
num_edges = graph.num_edges(canonical_etype) num_edges = graph.num_edges(canonical_etype)
std = nn.init.calculate_gain('relu') std = nn.init.calculate_gain("relu")
if num_nodes_sum > 0: if num_nodes_sum > 0:
std *= sqrt(2.0 / num_nodes_sum) std *= sqrt(2.0 / num_nodes_sum)
edge_masks[canonical_etype] = nn.Parameter( edge_masks[canonical_etype] = nn.Parameter(
torch.randn(num_edges, device=device) * std) torch.randn(num_edges, device=device) * std
)
return feat_masks, edge_masks return feat_masks, edge_masks
...@@ -576,8 +590,9 @@ class HeteroGNNExplainer(nn.Module): ...@@ -576,8 +590,9 @@ class HeteroGNNExplainer(nn.Module):
# Edge mask sparsity regularization # Edge mask sparsity regularization
loss = loss + self.alpha1 * torch.sum(edge_mask) loss = loss + self.alpha1 * torch.sum(edge_mask)
# Edge mask entropy regularization # Edge mask entropy regularization
ent = - edge_mask * torch.log(edge_mask + eps) - \ ent = -edge_mask * torch.log(edge_mask + eps) - (
(1 - edge_mask) * torch.log(1 - edge_mask + eps) 1 - edge_mask
) * torch.log(1 - edge_mask + eps)
loss = loss + self.alpha2 * ent.mean() loss = loss + self.alpha2 * ent.mean()
for feat_mask in feat_masks.values(): for feat_mask in feat_masks.values():
...@@ -585,8 +600,9 @@ class HeteroGNNExplainer(nn.Module): ...@@ -585,8 +600,9 @@ class HeteroGNNExplainer(nn.Module):
# Feature mask sparsity regularization # Feature mask sparsity regularization
loss = loss + self.beta1 * torch.mean(feat_mask) loss = loss + self.beta1 * torch.mean(feat_mask)
# Feature mask entropy regularization # Feature mask entropy regularization
ent = - feat_mask * torch.log(feat_mask + eps) - \ ent = -feat_mask * torch.log(feat_mask + eps) - (
(1 - feat_mask) * torch.log(1 - feat_mask + eps) 1 - feat_mask
) * torch.log(1 - feat_mask + eps)
loss = loss + self.beta2 * ent.mean() loss = loss + self.beta2 * ent.mean()
return loss return loss
...@@ -715,7 +731,9 @@ class HeteroGNNExplainer(nn.Module): ...@@ -715,7 +731,9 @@ class HeteroGNNExplainer(nn.Module):
# Extract node-centered k-hop subgraph and # Extract node-centered k-hop subgraph and
# its associated node and edge features. # its associated node and edge features.
sg, inverse_indices = khop_in_subgraph(graph, {ntype: node_id}, self.num_hops) sg, inverse_indices = khop_in_subgraph(
graph, {ntype: node_id}, self.num_hops
)
inverse_indices = inverse_indices[ntype] inverse_indices = inverse_indices[ntype]
sg_nodes = sg.ndata[NID] sg_nodes = sg.ndata[NID]
sg_feat = {} sg_feat = {}
...@@ -735,7 +753,7 @@ class HeteroGNNExplainer(nn.Module): ...@@ -735,7 +753,7 @@ class HeteroGNNExplainer(nn.Module):
if self.log: if self.log:
pbar = tqdm(total=self.num_epochs) pbar = tqdm(total=self.num_epochs)
pbar.set_description(f'Explain node {node_id} with type {ntype}') pbar.set_description(f"Explain node {node_id} with type {ntype}")
for _ in range(self.num_epochs): for _ in range(self.num_epochs):
optimizer.zero_grad() optimizer.zero_grad()
...@@ -745,8 +763,9 @@ class HeteroGNNExplainer(nn.Module): ...@@ -745,8 +763,9 @@ class HeteroGNNExplainer(nn.Module):
eweight = {} eweight = {}
for canonical_etype, canonical_etype_mask in edge_mask.items(): for canonical_etype, canonical_etype_mask in edge_mask.items():
eweight[canonical_etype] = canonical_etype_mask.sigmoid() eweight[canonical_etype] = canonical_etype_mask.sigmoid()
logits = self.model(graph=sg, feat=h, logits = self.model(graph=sg, feat=h, eweight=eweight, **kwargs)[
eweight=eweight, **kwargs)[ntype] ntype
]
log_probs = logits.log_softmax(dim=-1) log_probs = logits.log_softmax(dim=-1)
loss = -log_probs[inverse_indices, pred_label[inverse_indices]] loss = -log_probs[inverse_indices, pred_label[inverse_indices]]
loss = self._loss_regularize(loss, feat_mask, edge_mask) loss = self._loss_regularize(loss, feat_mask, edge_mask)
...@@ -760,10 +779,14 @@ class HeteroGNNExplainer(nn.Module): ...@@ -760,10 +779,14 @@ class HeteroGNNExplainer(nn.Module):
pbar.close() pbar.close()
for node_type in feat_mask: for node_type in feat_mask:
feat_mask[node_type] = feat_mask[node_type].detach().sigmoid().squeeze() feat_mask[node_type] = (
feat_mask[node_type].detach().sigmoid().squeeze()
)
for canonical_etype in edge_mask: for canonical_etype in edge_mask:
edge_mask[canonical_etype] = edge_mask[canonical_etype].detach().sigmoid() edge_mask[canonical_etype] = (
edge_mask[canonical_etype].detach().sigmoid()
)
return inverse_indices, sg, feat_mask, edge_mask return inverse_indices, sg, feat_mask, edge_mask
...@@ -882,7 +905,7 @@ class HeteroGNNExplainer(nn.Module): ...@@ -882,7 +905,7 @@ class HeteroGNNExplainer(nn.Module):
if self.log: if self.log:
pbar = tqdm(total=self.num_epochs) pbar = tqdm(total=self.num_epochs)
pbar.set_description('Explain graph') pbar.set_description("Explain graph")
for _ in range(self.num_epochs): for _ in range(self.num_epochs):
optimizer.zero_grad() optimizer.zero_grad()
...@@ -892,8 +915,7 @@ class HeteroGNNExplainer(nn.Module): ...@@ -892,8 +915,7 @@ class HeteroGNNExplainer(nn.Module):
eweight = {} eweight = {}
for canonical_etype, canonical_etype_mask in edge_mask.items(): for canonical_etype, canonical_etype_mask in edge_mask.items():
eweight[canonical_etype] = canonical_etype_mask.sigmoid() eweight[canonical_etype] = canonical_etype_mask.sigmoid()
logits = self.model(graph=graph, feat=h, logits = self.model(graph=graph, feat=h, eweight=eweight, **kwargs)
eweight=eweight, **kwargs)
log_probs = logits.log_softmax(dim=-1) log_probs = logits.log_softmax(dim=-1)
loss = -log_probs[0, pred_label[0]] loss = -log_probs[0, pred_label[0]]
loss = self._loss_regularize(loss, feat_mask, edge_mask) loss = self._loss_regularize(loss, feat_mask, edge_mask)
...@@ -907,9 +929,13 @@ class HeteroGNNExplainer(nn.Module): ...@@ -907,9 +929,13 @@ class HeteroGNNExplainer(nn.Module):
pbar.close() pbar.close()
for node_type in feat_mask: for node_type in feat_mask:
feat_mask[node_type] = feat_mask[node_type].detach().sigmoid().squeeze() feat_mask[node_type] = (
feat_mask[node_type].detach().sigmoid().squeeze()
)
for canonical_etype in edge_mask: for canonical_etype in edge_mask:
edge_mask[canonical_etype] = edge_mask[canonical_etype].detach().sigmoid() edge_mask[canonical_etype] = (
edge_mask[canonical_etype].detach().sigmoid()
)
return feat_mask, edge_mask return feat_mask, edge_mask
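A usage sketch for the explainer code above. The wrapper model is hypothetical; the only requirement visible in the training loops is that the model accept graph, feat, and an eweight keyword, and the assumed return order of explain_node is illustrative rather than quoted from the docstring:

import dgl
import torch
from torch import nn
from dgl.nn import GraphConv, GNNExplainer

class Model(nn.Module):
    # Hypothetical one-layer model; the explainer calls it as
    # model(graph=..., feat=..., eweight=...).
    def __init__(self, in_feats, num_classes):
        super().__init__()
        self.conv = GraphConv(in_feats, num_classes)

    def forward(self, graph, feat, eweight=None):
        return self.conv(graph, feat, edge_weight=eweight)

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feat = torch.randn(4, 5)
explainer = GNNExplainer(Model(5, 2), num_hops=1)
# The learned masks come back after the optimization loop shown above.
new_center, sg, feat_mask, edge_mask = explainer.explain_node(0, g, feat)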
@@ -338,6 +338,7 @@ class RadiusGraph(nn.Module):
            [0.7000],
            [0.2828]])
    """

    # pylint: disable=invalid-name
    def __init__(
        self,
...
@@ -2,18 +2,20 @@
# pylint: disable= invalid-name
import random

import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from torch.nn import init

from ...base import NID
from ...convert import to_heterogeneous, to_homogeneous
from ...random import choice
from ...sampling import random_walk

__all__ = ["DeepWalk", "MetaPath2Vec"]


class DeepWalk(nn.Module):
    """DeepWalk module from `DeepWalk: Online Learning of Social Representations
@@ -81,19 +83,23 @@ class DeepWalk(nn.Module):
    >>> clf = LogisticRegression().fit(X[train_mask].numpy(), y[train_mask].numpy())
    >>> clf.score(X[test_mask].numpy(), y[test_mask].numpy())
    """

    def __init__(
        self,
        g,
        emb_dim=128,
        walk_length=40,
        window_size=5,
        neg_weight=1,
        negative_size=5,
        fast_neg=True,
        sparse=True,
    ):
        super().__init__()

        assert (
            walk_length >= window_size + 1
        ), f"Expect walk_length >= window_size + 1, got {walk_length} and {window_size + 1}"

        self.g = g
        self.emb_dim = emb_dim
@@ -172,7 +178,9 @@ class DeepWalk(nn.Module):
        device = batch_walk.device

        batch_node_embed = self.node_embed(batch_walk).view(-1, self.emb_dim)
        batch_context_embed = self.context_embed(batch_walk).view(
            -1, self.emb_dim
        )

        batch_idx_list_offset = torch.arange(batch_size) * self.walk_length
        batch_idx_list_offset = batch_idx_list_offset.unsqueeze(1)
@@ -185,19 +193,23 @@ class DeepWalk(nn.Module):
        pos_dst_emb = batch_context_embed[idx_list_dst]

        neg_idx_list_src = idx_list_dst.unsqueeze(1) + torch.zeros(
            self.negative_size
        ).unsqueeze(0).to(device)
        neg_idx_list_src = neg_idx_list_src.view(-1)
        neg_src_emb = batch_node_embed[neg_idx_list_src.long()]

        if self.fast_neg:
            neg_idx_list_dst = list(range(batch_size * self.walk_length)) * (
                self.negative_size * self.window_size * 2
            )
            random.shuffle(neg_idx_list_dst)
            neg_idx_list_dst = neg_idx_list_dst[: len(neg_idx_list_src)]
            neg_idx_list_dst = torch.LongTensor(neg_idx_list_dst).to(device)
            neg_dst_emb = batch_context_embed[neg_idx_list_dst]
        else:
            neg_dst = choice(
                self.g.num_nodes(), size=len(neg_src_emb), prob=self.neg_prob
            )
            neg_dst_emb = self.context_embed(neg_dst.to(device))

        pos_score = torch.sum(torch.mul(pos_src_emb, pos_dst_emb), dim=1)
@@ -206,10 +218,15 @@ class DeepWalk(nn.Module):
        neg_score = torch.sum(torch.mul(neg_src_emb, neg_dst_emb), dim=1)
        neg_score = torch.clamp(neg_score, max=6, min=-6)
        neg_score = (
            torch.mean(-F.logsigmoid(-neg_score))
            * self.negative_size
            * self.neg_weight
        )

        return torch.mean(pos_score + neg_score)


class MetaPath2Vec(nn.Module):
    r"""metapath2vec module from `metapath2vec: Scalable Representation Learning for
    Heterogeneous Networks <https://dl.acm.org/doi/pdf/10.1145/3097983.3098036>`__
@@ -280,17 +297,21 @@ class MetaPath2Vec(nn.Module):
    >>> user_nids = torch.LongTensor(model.local_to_global_nid['user'])
    >>> user_emb = model.node_embed(user_nids)
    """

    def __init__(
        self,
        g,
        metapath,
        window_size,
        emb_dim=128,
        negative_size=5,
        sparse=True,
    ):
        super().__init__()

        assert (
            len(metapath) + 1 >= window_size
        ), f"Expect len(metapath) >= window_size - 1, got {metapath} and {window_size}"

        self.hg = g
        self.emb_dim = emb_dim
@@ -323,15 +344,21 @@ class MetaPath2Vec(nn.Module):
            traces, _ = random_walk(g=hg, nodes=[idx], metapath=metapath)
            for tr in traces.cpu().numpy():
                tr_nids = [
                    self.local_to_global_nid[node_metapath[i]][tr[i]]
                    for i in range(len(tr))
                ]
                node_frequency[torch.LongTensor(tr_nids)] += 1
        neg_prob = node_frequency.pow(0.75)
        self.neg_prob = neg_prob / neg_prob.sum()

        # center node embedding
        self.node_embed = nn.Embedding(
            num_nodes_total, self.emb_dim, sparse=sparse
        )
        self.context_embed = nn.Embedding(
            num_nodes_total, self.emb_dim, sparse=sparse
        )
        self.reset_parameters()

    def reset_parameters(self):
@@ -357,21 +384,30 @@ class MetaPath2Vec(nn.Module):
        torch.Tensor
            Negative context nodes
        """
        traces, _ = random_walk(
            g=self.hg, nodes=indices, metapath=self.metapath
        )
        u_list = []
        v_list = []
        for tr in traces.cpu().numpy():
            tr_nids = [
                self.local_to_global_nid[self.node_metapath[i]][tr[i]]
                for i in range(len(tr))
            ]
            for i, u in enumerate(tr_nids):
                for j, v in enumerate(
                    tr_nids[max(i - self.window_size, 0) : i + self.window_size]
                ):
                    if i == j:
                        continue
                    u_list.append(u)
                    v_list.append(v)

        neg_v = choice(
            self.hg.num_nodes(),
            size=len(u_list) * self.negative_size,
            prob=self.neg_prob,
        ).reshape(len(u_list), self.negative_size)
        return torch.LongTensor(u_list), torch.LongTensor(v_list), neg_v
...
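A condensed training sketch for the DeepWalk module above, following the pattern of its class docstring; the dataset, batch size, and learning rate are illustrative:

import torch
from torch.optim import SparseAdam
from torch.utils.data import DataLoader
from dgl.data import CoraGraphDataset
from dgl.nn import DeepWalk

g = CoraGraphDataset()[0]
model = DeepWalk(g)
# model.sample draws random walks for a batch of start nodes;
# forward(batch_walk) returns the skip-gram loss computed above.
loader = DataLoader(
    torch.arange(g.num_nodes()),
    batch_size=128,
    shuffle=True,
    collate_fn=model.sample,
)
optimizer = SparseAdam(model.parameters(), lr=0.01)
for batch_walk in loader:
    loss = model(batch_walk)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()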
@@ -5,8 +5,7 @@ import torch as th
import torch.nn.functional as F
from torch import nn

from ... import DGLGraph, function as fn
from ...base import dgl_warning
...
@@ -4,8 +4,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

from .... import broadcast_nodes, function as fn
from ....base import dgl_warning
...
"""Tensorflow modules for graph attention networks(GAT).""" """Tensorflow modules for graph attention networks(GAT)."""
import numpy as np
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
import numpy as np
from .... import function as fn from .... import function as fn
from ....base import DGLError from ....base import DGLError
...@@ -134,42 +135,60 @@ class GATConv(layers.Layer): ...@@ -134,42 +135,60 @@ class GATConv(layers.Layer):
[ 0.5088224 , 0.10908248], [ 0.5088224 , 0.10908248],
[ 0.55670375, -0.6811229 ]]], dtype=float32)> [ 0.55670375, -0.6811229 ]]], dtype=float32)>
""" """
def __init__(self,
in_feats, def __init__(
out_feats, self,
num_heads, in_feats,
feat_drop=0., out_feats,
attn_drop=0., num_heads,
negative_slope=0.2, feat_drop=0.0,
residual=False, attn_drop=0.0,
activation=None, negative_slope=0.2,
allow_zero_in_degree=False): residual=False,
activation=None,
allow_zero_in_degree=False,
):
super(GATConv, self).__init__() super(GATConv, self).__init__()
self._num_heads = num_heads self._num_heads = num_heads
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree self._allow_zero_in_degree = allow_zero_in_degree
xinit = tf.keras.initializers.VarianceScaling(scale=np.sqrt( xinit = tf.keras.initializers.VarianceScaling(
2), mode="fan_avg", distribution="untruncated_normal") scale=np.sqrt(2), mode="fan_avg", distribution="untruncated_normal"
)
if isinstance(in_feats, tuple): if isinstance(in_feats, tuple):
self.fc_src = layers.Dense( self.fc_src = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit) out_feats * num_heads, use_bias=False, kernel_initializer=xinit
)
self.fc_dst = layers.Dense( self.fc_dst = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit) out_feats * num_heads, use_bias=False, kernel_initializer=xinit
)
else: else:
self.fc = layers.Dense( self.fc = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit) out_feats * num_heads, use_bias=False, kernel_initializer=xinit
self.attn_l = tf.Variable(initial_value=xinit( )
shape=(1, num_heads, out_feats), dtype='float32'), trainable=True) self.attn_l = tf.Variable(
self.attn_r = tf.Variable(initial_value=xinit( initial_value=xinit(
shape=(1, num_heads, out_feats), dtype='float32'), trainable=True) shape=(1, num_heads, out_feats), dtype="float32"
),
trainable=True,
)
self.attn_r = tf.Variable(
initial_value=xinit(
shape=(1, num_heads, out_feats), dtype="float32"
),
trainable=True,
)
self.feat_drop = layers.Dropout(rate=feat_drop) self.feat_drop = layers.Dropout(rate=feat_drop)
self.attn_drop = layers.Dropout(rate=attn_drop) self.attn_drop = layers.Dropout(rate=attn_drop)
self.leaky_relu = layers.LeakyReLU(alpha=negative_slope) self.leaky_relu = layers.LeakyReLU(alpha=negative_slope)
if residual: if residual:
if in_feats != out_feats: if in_feats != out_feats:
self.res_fc = layers.Dense( self.res_fc = layers.Dense(
num_heads * out_feats, use_bias=False, kernel_initializer=xinit) num_heads * out_feats,
use_bias=False,
kernel_initializer=xinit,
)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
...@@ -220,39 +239,47 @@ class GATConv(layers.Layer): ...@@ -220,39 +239,47 @@ class GATConv(layers.Layer):
""" """
with graph.local_scope(): with graph.local_scope():
if not self._allow_zero_in_degree: if not self._allow_zero_in_degree:
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0: if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError('There are 0-in-degree nodes in the graph, ' raise DGLError(
'output for those nodes will be invalid. ' "There are 0-in-degree nodes in the graph, "
'This is harmful for some applications, ' "output for those nodes will be invalid. "
'causing silent performance regression. ' "This is harmful for some applications, "
'Adding self-loop on the input graph by ' "causing silent performance regression. "
'calling `g = dgl.add_self_loop(g)` will resolve ' "Adding self-loop on the input graph by "
'the issue. Setting ``allow_zero_in_degree`` ' "calling `g = dgl.add_self_loop(g)` will resolve "
'to be `True` when constructing this module will ' "the issue. Setting ``allow_zero_in_degree`` "
'suppress the check and let the code run.') "to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple): if isinstance(feat, tuple):
src_prefix_shape = tuple(feat[0].shape[:-1]) src_prefix_shape = tuple(feat[0].shape[:-1])
dst_prefix_shape = tuple(feat[1].shape[:-1]) dst_prefix_shape = tuple(feat[1].shape[:-1])
h_src = self.feat_drop(feat[0]) h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1]) h_dst = self.feat_drop(feat[1])
if not hasattr(self, 'fc_src'): if not hasattr(self, "fc_src"):
self.fc_src, self.fc_dst = self.fc, self.fc self.fc_src, self.fc_dst = self.fc, self.fc
feat_src = tf.reshape( feat_src = tf.reshape(
self.fc_src(h_src), self.fc_src(h_src),
src_prefix_shape + (self._num_heads, self._out_feats)) src_prefix_shape + (self._num_heads, self._out_feats),
)
feat_dst = tf.reshape( feat_dst = tf.reshape(
self.fc_dst(h_dst), self.fc_dst(h_dst),
dst_prefix_shape + (self._num_heads, self._out_feats)) dst_prefix_shape + (self._num_heads, self._out_feats),
)
else: else:
src_prefix_shape = dst_prefix_shape = tuple(feat.shape[:-1]) src_prefix_shape = dst_prefix_shape = tuple(feat.shape[:-1])
h_src = h_dst = self.feat_drop(feat) h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = tf.reshape( feat_src = feat_dst = tf.reshape(
self.fc(h_src), src_prefix_shape + (self._num_heads, self._out_feats)) self.fc(h_src),
src_prefix_shape + (self._num_heads, self._out_feats),
)
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()] h_dst = h_dst[: graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:] dst_prefix_shape = (
graph.number_of_dst_nodes(),
) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection" # NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then # to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent: # addition", the two approaches are mathematically equivalent:
...@@ -265,27 +292,27 @@ class GATConv(layers.Layer): ...@@ -265,27 +292,27 @@ class GATConv(layers.Layer):
# which further speeds up computation and saves memory footprint. # which further speeds up computation and saves memory footprint.
el = tf.reduce_sum(feat_src * self.attn_l, axis=-1, keepdims=True) el = tf.reduce_sum(feat_src * self.attn_l, axis=-1, keepdims=True)
er = tf.reduce_sum(feat_dst * self.attn_r, axis=-1, keepdims=True) er = tf.reduce_sum(feat_dst * self.attn_r, axis=-1, keepdims=True)
graph.srcdata.update({'ft': feat_src, 'el': el}) graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({'er': er}) graph.dstdata.update({"er": er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively. # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e')) graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(graph.edata.pop('e')) e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax # compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing # message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
fn.sum('m', 'ft')) rst = graph.dstdata["ft"]
rst = graph.dstdata['ft']
# residual # residual
if self.res_fc is not None: if self.res_fc is not None:
resval = tf.reshape(self.res_fc( resval = tf.reshape(
h_dst), dst_prefix_shape + (-1, self._out_feats)) self.res_fc(h_dst), dst_prefix_shape + (-1, self._out_feats)
)
rst = rst + resval rst = rst + resval
# activation # activation
if self.activation: if self.activation:
rst = self.activation(rst) rst = self.activation(rst)
if get_attention: if get_attention:
return rst, graph.edata['a'] return rst, graph.edata["a"]
else: else:
return rst return rst
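The NOTE in the GATConv code above says that concatenating the two projected features and applying one attention vector equals projecting with a_l and a_r separately and adding; a small standalone numeric check of that identity (not part of the commit):

import numpy as np

rng = np.random.default_rng(0)
out_feats = 4
a = rng.standard_normal(2 * out_feats)   # attention vector a = [a_l ; a_r]
a_l, a_r = a[:out_feats], a[out_feats:]
wh_i = rng.standard_normal(out_feats)    # projected source feature W h_i
wh_j = rng.standard_normal(out_feats)    # projected destination feature W h_j

concat_then_project = a @ np.concatenate([wh_i, wh_j])
project_then_add = a_l @ wh_i + a_r @ wh_j
assert np.isclose(concat_then_project, project_then_add)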
"""Tensorflow modules for graph convolutions(GCN).""" """Tensorflow modules for graph convolutions(GCN)."""
import numpy as np
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
import numpy as np
from .... import function as fn from .... import function as fn
from ....base import DGLError from ....base import DGLError
...@@ -134,18 +135,23 @@ class GraphConv(layers.Layer): ...@@ -134,18 +135,23 @@ class GraphConv(layers.Layer):
[ 2.1405895, -0.2574358], [ 2.1405895, -0.2574358],
[ 1.3607183, -0.1636453]], dtype=float32)> [ 1.3607183, -0.1636453]], dtype=float32)>
""" """
def __init__(self,
in_feats, def __init__(
out_feats, self,
norm='both', in_feats,
weight=True, out_feats,
bias=True, norm="both",
activation=None, weight=True,
allow_zero_in_degree=False): bias=True,
activation=None,
allow_zero_in_degree=False,
):
super(GraphConv, self).__init__() super(GraphConv, self).__init__()
if norm not in ('none', 'both', 'right', 'left'): if norm not in ("none", "both", "right", "left"):
raise DGLError('Invalid norm value. Must be either "none", "both", "right" or "left".' raise DGLError(
' But got "{}".'.format(norm)) 'Invalid norm value. Must be either "none", "both", "right" or "left".'
' But got "{}".'.format(norm)
)
self._in_feats = in_feats self._in_feats = in_feats
self._out_feats = out_feats self._out_feats = out_feats
self._norm = norm self._norm = norm
...@@ -153,15 +159,21 @@ class GraphConv(layers.Layer): ...@@ -153,15 +159,21 @@ class GraphConv(layers.Layer):
if weight: if weight:
xinit = tf.keras.initializers.glorot_uniform() xinit = tf.keras.initializers.glorot_uniform()
self.weight = tf.Variable(initial_value=xinit( self.weight = tf.Variable(
shape=(in_feats, out_feats), dtype='float32'), trainable=True) initial_value=xinit(
shape=(in_feats, out_feats), dtype="float32"
),
trainable=True,
)
else: else:
self.weight = None self.weight = None
if bias: if bias:
zeroinit = tf.keras.initializers.zeros() zeroinit = tf.keras.initializers.zeros()
self.bias = tf.Variable(initial_value=zeroinit( self.bias = tf.Variable(
shape=(out_feats), dtype='float32'), trainable=True) initial_value=zeroinit(shape=(out_feats), dtype="float32"),
trainable=True,
)
else: else:
self.bias = None self.bias = None
...@@ -216,23 +228,27 @@ class GraphConv(layers.Layer): ...@@ -216,23 +228,27 @@ class GraphConv(layers.Layer):
""" """
with graph.local_scope(): with graph.local_scope():
if not self._allow_zero_in_degree: if not self._allow_zero_in_degree:
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0: if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError('There are 0-in-degree nodes in the graph, ' raise DGLError(
'output for those nodes will be invalid. ' "There are 0-in-degree nodes in the graph, "
'This is harmful for some applications, ' "output for those nodes will be invalid. "
'causing silent performance regression. ' "This is harmful for some applications, "
'Adding self-loop on the input graph by ' "causing silent performance regression. "
'calling `g = dgl.add_self_loop(g)` will resolve ' "Adding self-loop on the input graph by "
'the issue. Setting ``allow_zero_in_degree`` ' "calling `g = dgl.add_self_loop(g)` will resolve "
'to be `True` when constructing this module will ' "the issue. Setting ``allow_zero_in_degree`` "
'suppress the check and let the code run.') "to be `True` when constructing this module will "
"suppress the check and let the code run."
)
feat_src, feat_dst = expand_as_pair(feat, graph) feat_src, feat_dst = expand_as_pair(feat, graph)
if self._norm in ['both', 'left']: if self._norm in ["both", "left"]:
degs = tf.clip_by_value(tf.cast(graph.out_degrees(), tf.float32), degs = tf.clip_by_value(
clip_value_min=1, tf.cast(graph.out_degrees(), tf.float32),
clip_value_max=np.inf) clip_value_min=1,
if self._norm == 'both': clip_value_max=np.inf,
)
if self._norm == "both":
norm = tf.pow(degs, -0.5) norm = tf.pow(degs, -0.5)
else: else:
norm = 1.0 / degs norm = 1.0 / degs
...@@ -242,9 +258,11 @@ class GraphConv(layers.Layer): ...@@ -242,9 +258,11 @@ class GraphConv(layers.Layer):
if weight is not None: if weight is not None:
if self.weight is not None: if self.weight is not None:
raise DGLError('External weight is provided while at the same time the' raise DGLError(
' module has defined its own weight parameter. Please' "External weight is provided while at the same time the"
' create the module with flag weight=False.') " module has defined its own weight parameter. Please"
" create the module with flag weight=False."
)
else: else:
weight = self.weight weight = self.weight
...@@ -252,24 +270,28 @@ class GraphConv(layers.Layer): ...@@ -252,24 +270,28 @@ class GraphConv(layers.Layer):
# mult W first to reduce the feature size for aggregation. # mult W first to reduce the feature size for aggregation.
if weight is not None: if weight is not None:
feat_src = tf.matmul(feat_src, weight) feat_src = tf.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src graph.srcdata["h"] = feat_src
graph.update_all(fn.copy_u(u='h', out='m'), graph.update_all(
fn.sum(msg='m', out='h')) fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
rst = graph.dstdata['h'] )
rst = graph.dstdata["h"]
else: else:
# aggregate first then mult W # aggregate first then mult W
graph.srcdata['h'] = feat_src graph.srcdata["h"] = feat_src
graph.update_all(fn.copy_u(u='h', out='m'), graph.update_all(
fn.sum(msg='m', out='h')) fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
rst = graph.dstdata['h'] )
rst = graph.dstdata["h"]
if weight is not None: if weight is not None:
rst = tf.matmul(rst, weight) rst = tf.matmul(rst, weight)
if self._norm in ['both', 'right']: if self._norm in ["both", "right"]:
degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32), degs = tf.clip_by_value(
clip_value_min=1, tf.cast(graph.in_degrees(), tf.float32),
clip_value_max=np.inf) clip_value_min=1,
if self._norm == 'both': clip_value_max=np.inf,
)
if self._norm == "both":
norm = tf.pow(degs, -0.5) norm = tf.pow(degs, -0.5)
else: else:
norm = 1.0 / degs norm = 1.0 / degs
...@@ -289,8 +311,8 @@ class GraphConv(layers.Layer): ...@@ -289,8 +311,8 @@ class GraphConv(layers.Layer):
"""Set the extra representation of the module, """Set the extra representation of the module,
which will come into effect when printing the model. which will come into effect when printing the model.
""" """
summary = 'in={_in_feats}, out={_out_feats}' summary = "in={_in_feats}, out={_out_feats}"
summary += ', normalization={_norm}' summary += ", normalization={_norm}"
if '_activation' in self.__dict__: if "_activation" in self.__dict__:
summary += ', activation={_activation}' summary += ", activation={_activation}"
return summary.format(**self.__dict__) return summary.format(**self.__dict__)
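For reference, a dense NumPy restatement of what the norm="both" branches above compute, i.e. symmetric degree normalization with degrees clipped to at least 1; the toy adjacency, features, and weight are illustrative:

import numpy as np

A = np.array([[0, 1, 0],
              [0, 0, 1],
              [1, 0, 0]], dtype=float)   # A[u, v] = 1 for an edge u -> v
X = np.random.rand(3, 2)                 # node features
W = np.random.rand(2, 2)                 # layer weight

d_out = np.clip(A.sum(axis=1), 1, None)  # out-degrees, clipped like tf.clip_by_value
d_in = np.clip(A.sum(axis=0), 1, None)   # in-degrees
# Messages flow src -> dst, so aggregation applies A transposed:
# rst = D_in^{-1/2} A^T D_out^{-1/2} X W
rst = np.diag(d_in ** -0.5) @ A.T @ np.diag(d_out ** -0.5) @ X @ W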
@@ -5,7 +5,7 @@ from tensorflow.keras import layers
from .... import function as fn
from ....base import DGLError
from ....utils import check_eq_shape, expand_as_pair


class SAGEConv(layers.Layer):
@@ -88,20 +88,25 @@ class SAGEConv(layers.Layer):
           [ 0.3221837 , -0.29876417],
           [-0.63356155,  0.09390211]], dtype=float32)>
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        aggregator_type,
        feat_drop=0.0,
        bias=True,
        norm=None,
        activation=None,
    ):
        super(SAGEConv, self).__init__()
        valid_aggre_types = {"mean", "gcn", "pool", "lstm"}
        if aggregator_type not in valid_aggre_types:
            raise DGLError(
                "Invalid aggregator_type. Must be one of {}. "
                "But got {!r} instead.".format(
                    valid_aggre_types, aggregator_type
                )
            )
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
@@ -111,11 +116,11 @@ class SAGEConv(layers.Layer):
        self.feat_drop = layers.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == "pool":
            self.fc_pool = layers.Dense(self._in_src_feats)
        if aggregator_type == "lstm":
            self.lstm = layers.LSTM(units=self._in_src_feats)
        if aggregator_type != "gcn":
            self.fc_self = layers.Dense(out_feats, use_bias=bias)
        self.fc_neigh = layers.Dense(out_feats, use_bias=bias)
@@ -124,9 +129,9 @@ class SAGEConv(layers.Layer):
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox["m"]  # (B, L, D)
        rst = self.lstm(m)
        return {"neigh": rst}

    def call(self, graph, feat):
        r"""Compute GraphSAGE layer.
@@ -155,41 +160,47 @@ class SAGEConv(layers.Layer):
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]
            h_self = feat_dst

            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata["neigh"] = tf.cast(
                    tf.zeros((graph.number_of_dst_nodes(), self._in_src_feats)),
                    tf.float32,
                )

            if self._aggre_type == "mean":
                graph.srcdata["h"] = feat_src
                graph.update_all(fn.copy_u("h", "m"), fn.mean("m", "neigh"))
                h_neigh = graph.dstdata["neigh"]
            elif self._aggre_type == "gcn":
                check_eq_shape(feat)
                graph.srcdata["h"] = feat_src
                graph.dstdata["h"] = feat_dst  # same as above if homogeneous
                graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "neigh"))
                # divide in_degrees
                degs = tf.cast(graph.in_degrees(), tf.float32)
                h_neigh = (graph.dstdata["neigh"] + graph.dstdata["h"]) / (
                    tf.expand_dims(degs, -1) + 1
                )
            elif self._aggre_type == "pool":
                graph.srcdata["h"] = tf.nn.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_u("h", "m"), fn.max("m", "neigh"))
                h_neigh = graph.dstdata["neigh"]
            elif self._aggre_type == "lstm":
                graph.srcdata["h"] = feat_src
                graph.update_all(fn.copy_u("h", "m"), self._lstm_reducer)
                h_neigh = graph.dstdata["neigh"]
            else:
                raise KeyError(
                    "Aggregator type {} not recognized.".format(
                        self._aggre_type
                    )
                )

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == "gcn":
                rst = self.fc_neigh(h_neigh)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
...
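The "gcn" aggregator in both SAGEConv implementations above averages a node's own feature together with the features of its in-neighbors, h_neigh[v] = (sum of neighbor features + h[v]) / (in_degree(v) + 1); a plain NumPy sketch of that update on toy data (illustrative only):

import numpy as np

A = np.array([[0, 1, 1],
              [0, 0, 1],
              [0, 0, 0]], dtype=float)         # A[u, v] = 1 for an edge u -> v
h = np.arange(6, dtype=float).reshape(3, 2)    # toy node features

neigh_sum = A.T @ h                            # sum over in-neighbors
deg = A.sum(axis=0)                            # in-degrees
h_neigh = (neigh_sum + h) / (deg[:, None] + 1) # matches the gcn branch above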