Unverified Commit 63ac788f authored by Hongzhi (Steve), Chen, committed by GitHub

auto-reformat-nn (#5319)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 0b3a447b
......@@ -6,7 +6,7 @@ from torch.nn import functional as F
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape
from ....utils import check_eq_shape, expand_as_pair
class SAGEConv(nn.Module):
......@@ -94,20 +94,25 @@ class SAGEConv(nn.Module):
[ 0.5873, 1.6597],
[-0.2502, 2.8068]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feats,
out_feats,
aggregator_type,
feat_drop=0.,
bias=True,
norm=None,
activation=None):
def __init__(
self,
in_feats,
out_feats,
aggregator_type,
feat_drop=0.0,
bias=True,
norm=None,
activation=None,
):
super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
valid_aggre_types = {"mean", "gcn", "pool", "lstm"}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
"Invalid aggregator_type. Must be one of {}. "
"But got {!r} instead.".format(
valid_aggre_types, aggregator_type
)
)
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
......@@ -118,14 +123,16 @@ class SAGEConv(nn.Module):
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
if aggregator_type == "pool":
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type == "lstm":
self.lstm = nn.LSTM(
self._in_src_feats, self._in_src_feats, batch_first=True
)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
if aggregator_type != 'gcn':
if aggregator_type != "gcn":
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
elif bias:
self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
......@@ -146,12 +153,12 @@ class SAGEConv(nn.Module):
The linear weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
The LSTM module uses the Xavier initialization method for its weights.
"""
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'pool':
gain = nn.init.calculate_gain("relu")
if self._aggre_type == "pool":
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
if self._aggre_type == "lstm":
self.lstm.reset_parameters()
if self._aggre_type != 'gcn':
if self._aggre_type != "gcn":
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
......@@ -160,12 +167,14 @@ class SAGEConv(nn.Module):
NOTE(zihao): lstm reducer with default schedule (degree bucketing)
is slow; we could accelerate this with degree padding in the future.
"""
m = nodes.mailbox['m'] # (B, L, D)
m = nodes.mailbox["m"] # (B, L, D)
batch_size = m.shape[0]
h = (m.new_zeros((1, batch_size, self._in_src_feats)),
m.new_zeros((1, batch_size, self._in_src_feats)))
h = (
m.new_zeros((1, batch_size, self._in_src_feats)),
m.new_zeros((1, batch_size, self._in_src_feats)),
)
_, (rst, _) = self.lstm(m, h)
return {'neigh': rst.squeeze(0)}
return {"neigh": rst.squeeze(0)}
def forward(self, graph, feat, edge_weight=None):
r"""
......@@ -202,59 +211,74 @@ class SAGEConv(nn.Module):
else:
feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
msg_fn = fn.copy_u('h', 'm')
feat_dst = feat_src[: graph.number_of_dst_nodes()]
msg_fn = fn.copy_u("h", "m")
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')
graph.edata["_edge_weight"] = edge_weight
msg_fn = fn.u_mul_e("h", "_edge_weight", "m")
h_self = feat_dst
# Handle the case of graphs without edges
if graph.number_of_edges() == 0:
graph.dstdata['neigh'] = torch.zeros(
feat_dst.shape[0], self._in_src_feats).to(feat_dst)
graph.dstdata["neigh"] = torch.zeros(
feat_dst.shape[0], self._in_src_feats
).to(feat_dst)
# Determine whether to apply linear transformation before message passing A(XW)
lin_before_mp = self._in_src_feats > self._out_feats
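# Because the mean/sum aggregation is linear, A(XW) equals (AX)W; projecting
# with fc_neigh first when in_feats > out_feats shrinks the messages that are
# aggregated, saving compute and memory.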
# Message Passing
if self._aggre_type == 'mean':
graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
graph.update_all(msg_fn, fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
if self._aggre_type == "mean":
graph.srcdata["h"] = (
self.fc_neigh(feat_src) if lin_before_mp else feat_src
)
graph.update_all(msg_fn, fn.mean("m", "neigh"))
h_neigh = graph.dstdata["neigh"]
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)
elif self._aggre_type == 'gcn':
elif self._aggre_type == "gcn":
check_eq_shape(feat)
graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
graph.srcdata["h"] = (
self.fc_neigh(feat_src) if lin_before_mp else feat_src
)
if isinstance(feat, tuple): # heterogeneous
graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
graph.dstdata["h"] = (
self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
)
else:
if graph.is_block:
graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
graph.dstdata["h"] = graph.srcdata["h"][
: graph.num_dst_nodes()
]
else:
graph.dstdata['h'] = graph.srcdata['h']
graph.update_all(msg_fn, fn.sum('m', 'neigh'))
graph.dstdata["h"] = graph.srcdata["h"]
graph.update_all(msg_fn, fn.sum("m", "neigh"))
# divide in_degrees
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
h_neigh = (graph.dstdata["neigh"] + graph.dstdata["h"]) / (
degs.unsqueeze(-1) + 1
)
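# The GCN-style aggregator averages over the neighborhood plus the node
# itself, hence the extra graph.dstdata["h"] term and the (deg + 1) denominator.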
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)
elif self._aggre_type == 'pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(msg_fn, fn.max('m', 'neigh'))
h_neigh = self.fc_neigh(graph.dstdata['neigh'])
elif self._aggre_type == 'lstm':
graph.srcdata['h'] = feat_src
elif self._aggre_type == "pool":
graph.srcdata["h"] = F.relu(self.fc_pool(feat_src))
graph.update_all(msg_fn, fn.max("m", "neigh"))
h_neigh = self.fc_neigh(graph.dstdata["neigh"])
elif self._aggre_type == "lstm":
graph.srcdata["h"] = feat_src
graph.update_all(msg_fn, self._lstm_reducer)
h_neigh = self.fc_neigh(graph.dstdata['neigh'])
h_neigh = self.fc_neigh(graph.dstdata["neigh"])
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
raise KeyError(
"Aggregator type {} not recognized.".format(
self._aggre_type
)
)
# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
if self._aggre_type == "gcn":
rst = h_neigh
# add bias manually for GCN
if self.bias is not None:
......@@ -262,7 +286,6 @@ class SAGEConv(nn.Module):
else:
rst = self.fc_self(h_self) + h_neigh
# activation
if self.activation is not None:
rst = self.activation(rst)
......
......@@ -98,7 +98,6 @@ class TWIRLSConv(nn.Module):
attn_dropout=0.0,
inp_dropout=0.0,
):
super().__init__()
self.input_d = input_d
self.output_d = output_d
......@@ -542,7 +541,6 @@ class TWIRLSUnfoldingAndAttention(nn.Module):
attn_dropout=0,
precond=True,
):
super().__init__()
self.d = d
......@@ -596,7 +594,6 @@ class TWIRLSUnfoldingAndAttention(nn.Module):
g = self.init_attn(g, Y, self.etas)
for k, layer in enumerate(self.prop_layers):
# do unfolding
Y = layer(g, Y, X, self.alp, self.lam)
......
"""Torch Module for GNNExplainer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from math import sqrt
import torch
from torch import nn
from tqdm import tqdm
from ....base import NID, EID
from ....base import EID, NID
from ....subgraph import khop_in_subgraph
__all__ = ['GNNExplainer', 'HeteroGNNExplainer']
__all__ = ["GNNExplainer", "HeteroGNNExplainer"]
class GNNExplainer(nn.Module):
r"""GNNExplainer model from `GNNExplainer: Generating Explanations for
......@@ -63,17 +65,19 @@ class GNNExplainer(nn.Module):
If True, it will log the computation process; defaults to True.
"""
def __init__(self,
model,
num_hops,
lr=0.01,
num_epochs=100,
*,
alpha1=0.005,
alpha2=1.0,
beta1=1.0,
beta2=0.1,
log=True):
def __init__(
self,
model,
num_hops,
lr=0.01,
num_epochs=100,
*,
alpha1=0.005,
alpha2=1.0,
beta1=1.0,
beta2=0.1,
log=True,
):
super(GNNExplainer, self).__init__()
self.model = model
self.num_hops = num_hops
......@@ -111,7 +115,7 @@ class GNNExplainer(nn.Module):
std = 0.1
feat_mask = nn.Parameter(torch.randn(1, feat_size, device=device) * std)
std = nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * num_nodes))
std = nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * num_nodes))
edge_mask = nn.Parameter(torch.randn(num_edges, device=device) * std)
return feat_mask, edge_mask
......@@ -142,16 +146,18 @@ class GNNExplainer(nn.Module):
# Edge mask sparsity regularization
loss = loss + self.alpha1 * torch.sum(edge_mask)
# Edge mask entropy regularization
ent = - edge_mask * torch.log(edge_mask + eps) - \
(1 - edge_mask) * torch.log(1 - edge_mask + eps)
ent = -edge_mask * torch.log(edge_mask + eps) - (
1 - edge_mask
) * torch.log(1 - edge_mask + eps)
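# Binary entropy of the mask values; minimizing it pushes each mask entry
# towards 0 or 1 (eps guards against log(0)).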
loss = loss + self.alpha2 * ent.mean()
feat_mask = feat_mask.sigmoid()
# Feature mask sparsity regularization
loss = loss + self.beta1 * torch.mean(feat_mask)
# Feature mask entropy regularization
ent = - feat_mask * torch.log(feat_mask + eps) - \
(1 - feat_mask) * torch.log(1 - feat_mask + eps)
ent = -feat_mask * torch.log(feat_mask + eps) - (
1 - feat_mask
) * torch.log(1 - feat_mask + eps)
loss = loss + self.beta2 * ent.mean()
return loss
......@@ -285,13 +291,14 @@ class GNNExplainer(nn.Module):
if self.log:
pbar = tqdm(total=self.num_epochs)
pbar.set_description(f'Explain node {node_id}')
pbar.set_description(f"Explain node {node_id}")
for _ in range(self.num_epochs):
optimizer.zero_grad()
h = feat * feat_mask.sigmoid()
logits = self.model(graph=sg, feat=h,
eweight=edge_mask.sigmoid(), **kwargs)
logits = self.model(
graph=sg, feat=h, eweight=edge_mask.sigmoid(), **kwargs
)
log_probs = logits.log_softmax(dim=-1)
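# Negative log-likelihood of the model's original prediction for the
# explained node, evaluated under the current feature and edge masks.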
loss = -log_probs[inverse_indices, pred_label[inverse_indices]]
loss = self._loss_regularize(loss, feat_mask, edge_mask)
......@@ -406,13 +413,14 @@ class GNNExplainer(nn.Module):
if self.log:
pbar = tqdm(total=self.num_epochs)
pbar.set_description('Explain graph')
pbar.set_description("Explain graph")
for _ in range(self.num_epochs):
optimizer.zero_grad()
h = feat * feat_mask.sigmoid()
logits = self.model(graph=graph, feat=h,
eweight=edge_mask.sigmoid(), **kwargs)
logits = self.model(
graph=graph, feat=h, eweight=edge_mask.sigmoid(), **kwargs
)
log_probs = logits.log_softmax(dim=-1)
loss = -log_probs[0, pred_label[0]]
loss = self._loss_regularize(loss, feat_mask, edge_mask)
......@@ -430,6 +438,7 @@ class GNNExplainer(nn.Module):
return feat_mask, edge_mask
class HeteroGNNExplainer(nn.Module):
r"""GNNExplainer model from `GNNExplainer: Generating Explanations for
Graph Neural Networks <https://arxiv.org/abs/1903.03894>`__, adapted for heterogeneous graphs
......@@ -482,17 +491,19 @@ class HeteroGNNExplainer(nn.Module):
If True, it will log the computation process; defaults to True.
"""
def __init__(self,
model,
num_hops,
lr=0.01,
num_epochs=100,
*,
alpha1=0.005,
alpha2=1.0,
beta1=1.0,
beta2=0.1,
log=True):
def __init__(
self,
model,
num_hops,
lr=0.01,
num_epochs=100,
*,
alpha1=0.005,
alpha2=1.0,
beta1=1.0,
beta2=0.1,
log=True,
):
super(HeteroGNNExplainer, self).__init__()
self.model = model
self.num_hops = num_hops
......@@ -531,7 +542,9 @@ class HeteroGNNExplainer(nn.Module):
std = 0.1
for node_type, feature in feat.items():
_, feat_size = feature.size()
feat_masks[node_type] = nn.Parameter(torch.randn(1, feat_size, device=device) * std)
feat_masks[node_type] = nn.Parameter(
torch.randn(1, feat_size, device=device) * std
)
edge_masks = {}
for canonical_etype in graph.canonical_etypes:
......@@ -539,11 +552,12 @@ class HeteroGNNExplainer(nn.Module):
dst_num_nodes = graph.num_nodes(canonical_etype[-1])
num_nodes_sum = src_num_nodes + dst_num_nodes
num_edges = graph.num_edges(canonical_etype)
std = nn.init.calculate_gain('relu')
std = nn.init.calculate_gain("relu")
if num_nodes_sum > 0:
std *= sqrt(2.0 / num_nodes_sum)
edge_masks[canonical_etype] = nn.Parameter(
torch.randn(num_edges, device=device) * std)
torch.randn(num_edges, device=device) * std
)
return feat_masks, edge_masks
......@@ -576,8 +590,9 @@ class HeteroGNNExplainer(nn.Module):
# Edge mask sparsity regularization
loss = loss + self.alpha1 * torch.sum(edge_mask)
# Edge mask entropy regularization
ent = - edge_mask * torch.log(edge_mask + eps) - \
(1 - edge_mask) * torch.log(1 - edge_mask + eps)
ent = -edge_mask * torch.log(edge_mask + eps) - (
1 - edge_mask
) * torch.log(1 - edge_mask + eps)
loss = loss + self.alpha2 * ent.mean()
for feat_mask in feat_masks.values():
......@@ -585,8 +600,9 @@ class HeteroGNNExplainer(nn.Module):
# Feature mask sparsity regularization
loss = loss + self.beta1 * torch.mean(feat_mask)
# Feature mask entropy regularization
ent = - feat_mask * torch.log(feat_mask + eps) - \
(1 - feat_mask) * torch.log(1 - feat_mask + eps)
ent = -feat_mask * torch.log(feat_mask + eps) - (
1 - feat_mask
) * torch.log(1 - feat_mask + eps)
loss = loss + self.beta2 * ent.mean()
return loss
......@@ -715,7 +731,9 @@ class HeteroGNNExplainer(nn.Module):
# Extract node-centered k-hop subgraph and
# its associated node and edge features.
sg, inverse_indices = khop_in_subgraph(graph, {ntype: node_id}, self.num_hops)
sg, inverse_indices = khop_in_subgraph(
graph, {ntype: node_id}, self.num_hops
)
inverse_indices = inverse_indices[ntype]
sg_nodes = sg.ndata[NID]
sg_feat = {}
......@@ -735,7 +753,7 @@ class HeteroGNNExplainer(nn.Module):
if self.log:
pbar = tqdm(total=self.num_epochs)
pbar.set_description(f'Explain node {node_id} with type {ntype}')
pbar.set_description(f"Explain node {node_id} with type {ntype}")
for _ in range(self.num_epochs):
optimizer.zero_grad()
......@@ -745,8 +763,9 @@ class HeteroGNNExplainer(nn.Module):
eweight = {}
for canonical_etype, canonical_etype_mask in edge_mask.items():
eweight[canonical_etype] = canonical_etype_mask.sigmoid()
logits = self.model(graph=sg, feat=h,
eweight=eweight, **kwargs)[ntype]
logits = self.model(graph=sg, feat=h, eweight=eweight, **kwargs)[
ntype
]
log_probs = logits.log_softmax(dim=-1)
loss = -log_probs[inverse_indices, pred_label[inverse_indices]]
loss = self._loss_regularize(loss, feat_mask, edge_mask)
......@@ -760,10 +779,14 @@ class HeteroGNNExplainer(nn.Module):
pbar.close()
for node_type in feat_mask:
feat_mask[node_type] = feat_mask[node_type].detach().sigmoid().squeeze()
feat_mask[node_type] = (
feat_mask[node_type].detach().sigmoid().squeeze()
)
for canonical_etype in edge_mask:
edge_mask[canonical_etype] = edge_mask[canonical_etype].detach().sigmoid()
edge_mask[canonical_etype] = (
edge_mask[canonical_etype].detach().sigmoid()
)
return inverse_indices, sg, feat_mask, edge_mask
......@@ -882,7 +905,7 @@ class HeteroGNNExplainer(nn.Module):
if self.log:
pbar = tqdm(total=self.num_epochs)
pbar.set_description('Explain graph')
pbar.set_description("Explain graph")
for _ in range(self.num_epochs):
optimizer.zero_grad()
......@@ -892,8 +915,7 @@ class HeteroGNNExplainer(nn.Module):
eweight = {}
for canonical_etype, canonical_etype_mask in edge_mask.items():
eweight[canonical_etype] = canonical_etype_mask.sigmoid()
logits = self.model(graph=graph, feat=h,
eweight=eweight, **kwargs)
logits = self.model(graph=graph, feat=h, eweight=eweight, **kwargs)
log_probs = logits.log_softmax(dim=-1)
loss = -log_probs[0, pred_label[0]]
loss = self._loss_regularize(loss, feat_mask, edge_mask)
......@@ -907,9 +929,13 @@ class HeteroGNNExplainer(nn.Module):
pbar.close()
for node_type in feat_mask:
feat_mask[node_type] = feat_mask[node_type].detach().sigmoid().squeeze()
feat_mask[node_type] = (
feat_mask[node_type].detach().sigmoid().squeeze()
)
for canonical_etype in edge_mask:
edge_mask[canonical_etype] = edge_mask[canonical_etype].detach().sigmoid()
edge_mask[canonical_etype] = (
edge_mask[canonical_etype].detach().sigmoid()
)
return feat_mask, edge_mask
......@@ -338,6 +338,7 @@ class RadiusGraph(nn.Module):
[0.7000],
[0.2828]])
"""
# pylint: disable=invalid-name
def __init__(
self,
......
......@@ -2,18 +2,20 @@
# pylint: disable= invalid-name
import random
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import tqdm
from torch import nn
from torch.nn import init
from ...base import NID
from ...convert import to_homogeneous, to_heterogeneous
from ...convert import to_heterogeneous, to_homogeneous
from ...random import choice
from ...sampling import random_walk
__all__ = ['DeepWalk', 'MetaPath2Vec']
__all__ = ["DeepWalk", "MetaPath2Vec"]
class DeepWalk(nn.Module):
"""DeepWalk module from `DeepWalk: Online Learning of Social Representations
......@@ -81,19 +83,23 @@ class DeepWalk(nn.Module):
>>> clf = LogisticRegression().fit(X[train_mask].numpy(), y[train_mask].numpy())
>>> clf.score(X[test_mask].numpy(), y[test_mask].numpy())
"""
def __init__(self,
g,
emb_dim=128,
walk_length=40,
window_size=5,
neg_weight=1,
negative_size=5,
fast_neg=True,
sparse=True):
def __init__(
self,
g,
emb_dim=128,
walk_length=40,
window_size=5,
neg_weight=1,
negative_size=5,
fast_neg=True,
sparse=True,
):
super().__init__()
assert walk_length >= window_size + 1, \
f'Expect walk_length >= window_size + 1, got {walk_length} and {window_size + 1}'
assert (
walk_length >= window_size + 1
), f"Expect walk_length >= window_size + 1, got {walk_length} and {window_size + 1}"
self.g = g
self.emb_dim = emb_dim
......@@ -172,7 +178,9 @@ class DeepWalk(nn.Module):
device = batch_walk.device
batch_node_embed = self.node_embed(batch_walk).view(-1, self.emb_dim)
batch_context_embed = self.context_embed(batch_walk).view(-1, self.emb_dim)
batch_context_embed = self.context_embed(batch_walk).view(
-1, self.emb_dim
)
batch_idx_list_offset = torch.arange(batch_size) * self.walk_length
batch_idx_list_offset = batch_idx_list_offset.unsqueeze(1)
......@@ -185,19 +193,23 @@ class DeepWalk(nn.Module):
pos_dst_emb = batch_context_embed[idx_list_dst]
neg_idx_list_src = idx_list_dst.unsqueeze(1) + torch.zeros(
self.negative_size).unsqueeze(0).to(device)
self.negative_size
).unsqueeze(0).to(device)
neg_idx_list_src = neg_idx_list_src.view(-1)
neg_src_emb = batch_node_embed[neg_idx_list_src.long()]
if self.fast_neg:
neg_idx_list_dst = list(range(batch_size * self.walk_length)) \
* (self.negative_size * self.window_size * 2)
neg_idx_list_dst = list(range(batch_size * self.walk_length)) * (
self.negative_size * self.window_size * 2
)
random.shuffle(neg_idx_list_dst)
neg_idx_list_dst = neg_idx_list_dst[:len(neg_idx_list_src)]
neg_idx_list_dst = neg_idx_list_dst[: len(neg_idx_list_src)]
neg_idx_list_dst = torch.LongTensor(neg_idx_list_dst).to(device)
neg_dst_emb = batch_context_embed[neg_idx_list_dst]
else:
neg_dst = choice(self.g.num_nodes(), size=len(neg_src_emb), prob=self.neg_prob)
neg_dst = choice(
self.g.num_nodes(), size=len(neg_src_emb), prob=self.neg_prob
)
neg_dst_emb = self.context_embed(neg_dst.to(device))
pos_score = torch.sum(torch.mul(pos_src_emb, pos_dst_emb), dim=1)
......@@ -206,10 +218,15 @@ class DeepWalk(nn.Module):
neg_score = torch.sum(torch.mul(neg_src_emb, neg_dst_emb), dim=1)
neg_score = torch.clamp(neg_score, max=6, min=-6)
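# Negatives are scored with -logsigmoid(-score), as in skip-gram with
# negative sampling; the clamp keeps the logits in [-6, 6] for numerical
# stability.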
neg_score = torch.mean(-F.logsigmoid(-neg_score)) * self.negative_size * self.neg_weight
neg_score = (
torch.mean(-F.logsigmoid(-neg_score))
* self.negative_size
* self.neg_weight
)
return torch.mean(pos_score + neg_score)
class MetaPath2Vec(nn.Module):
r"""metapath2vec module from `metapath2vec: Scalable Representation Learning for
Heterogeneous Networks <https://dl.acm.org/doi/pdf/10.1145/3097983.3098036>`__
......@@ -280,17 +297,21 @@ class MetaPath2Vec(nn.Module):
>>> user_nids = torch.LongTensor(model.local_to_global_nid['user'])
>>> user_emb = model.node_embed(user_nids)
"""
def __init__(self,
g,
metapath,
window_size,
emb_dim=128,
negative_size=5,
sparse=True):
def __init__(
self,
g,
metapath,
window_size,
emb_dim=128,
negative_size=5,
sparse=True,
):
super().__init__()
assert len(metapath) + 1 >= window_size, \
f'Expect len(metapath) >= window_size - 1, got {metapath} and {window_size}'
assert (
len(metapath) + 1 >= window_size
), f"Expect len(metapath) >= window_size - 1, got {metapath} and {window_size}"
self.hg = g
self.emb_dim = emb_dim
......@@ -323,15 +344,21 @@ class MetaPath2Vec(nn.Module):
traces, _ = random_walk(g=hg, nodes=[idx], metapath=metapath)
for tr in traces.cpu().numpy():
tr_nids = [
self.local_to_global_nid[node_metapath[i]][tr[i]] for i in range(len(tr))]
self.local_to_global_nid[node_metapath[i]][tr[i]]
for i in range(len(tr))
]
node_frequency[torch.LongTensor(tr_nids)] += 1
neg_prob = node_frequency.pow(0.75)
self.neg_prob = neg_prob / neg_prob.sum()
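# Negative nodes are drawn with probability proportional to frequency^0.75,
# the unigram distribution used in word2vec-style negative sampling.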
# center node embedding
self.node_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse)
self.context_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse)
self.node_embed = nn.Embedding(
num_nodes_total, self.emb_dim, sparse=sparse
)
self.context_embed = nn.Embedding(
num_nodes_total, self.emb_dim, sparse=sparse
)
self.reset_parameters()
def reset_parameters(self):
......@@ -357,21 +384,30 @@ class MetaPath2Vec(nn.Module):
torch.Tensor
Negative context nodes
"""
traces, _ = random_walk(g=self.hg, nodes=indices, metapath=self.metapath)
traces, _ = random_walk(
g=self.hg, nodes=indices, metapath=self.metapath
)
u_list = []
v_list = []
for tr in traces.cpu().numpy():
tr_nids = [
self.local_to_global_nid[self.node_metapath[i]][tr[i]] for i in range(len(tr))]
self.local_to_global_nid[self.node_metapath[i]][tr[i]]
for i in range(len(tr))
]
for i, u in enumerate(tr_nids):
for j, v in enumerate(tr_nids[max(i - self.window_size, 0):i + self.window_size]):
for j, v in enumerate(
tr_nids[max(i - self.window_size, 0) : i + self.window_size]
):
if i == j:
continue
u_list.append(u)
v_list.append(v)
neg_v = choice(self.hg.num_nodes(), size=len(u_list) * self.negative_size,
prob=self.neg_prob).reshape(len(u_list), self.negative_size)
neg_v = choice(
self.hg.num_nodes(),
size=len(u_list) * self.negative_size,
prob=self.neg_prob,
).reshape(len(u_list), self.negative_size)
return torch.LongTensor(u_list), torch.LongTensor(v_list), neg_v
......
......@@ -5,8 +5,7 @@ import torch as th
import torch.nn.functional as F
from torch import nn
from ... import DGLGraph
from ... import function as fn
from ... import DGLGraph, function as fn
from ...base import dgl_warning
......
......@@ -4,8 +4,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from .... import broadcast_nodes
from .... import function as fn
from .... import broadcast_nodes, function as fn
from ....base import dgl_warning
......
"""Tensorflow modules for graph attention networks(GAT)."""
import numpy as np
# pylint: disable= no-member, arguments-differ, invalid-name
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from .... import function as fn
from ....base import DGLError
......@@ -134,42 +135,60 @@ class GATConv(layers.Layer):
[ 0.5088224 , 0.10908248],
[ 0.55670375, -0.6811229 ]]], dtype=float32)>
"""
def __init__(self,
in_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False):
def __init__(
self,
in_feats,
out_feats,
num_heads,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
):
super(GATConv, self).__init__()
self._num_heads = num_heads
self._in_feats = in_feats
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
xinit = tf.keras.initializers.VarianceScaling(scale=np.sqrt(
2), mode="fan_avg", distribution="untruncated_normal")
xinit = tf.keras.initializers.VarianceScaling(
scale=np.sqrt(2), mode="fan_avg", distribution="untruncated_normal"
)
if isinstance(in_feats, tuple):
self.fc_src = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
out_feats * num_heads, use_bias=False, kernel_initializer=xinit
)
self.fc_dst = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
out_feats * num_heads, use_bias=False, kernel_initializer=xinit
)
else:
self.fc = layers.Dense(
out_feats * num_heads, use_bias=False, kernel_initializer=xinit)
self.attn_l = tf.Variable(initial_value=xinit(
shape=(1, num_heads, out_feats), dtype='float32'), trainable=True)
self.attn_r = tf.Variable(initial_value=xinit(
shape=(1, num_heads, out_feats), dtype='float32'), trainable=True)
out_feats * num_heads, use_bias=False, kernel_initializer=xinit
)
self.attn_l = tf.Variable(
initial_value=xinit(
shape=(1, num_heads, out_feats), dtype="float32"
),
trainable=True,
)
self.attn_r = tf.Variable(
initial_value=xinit(
shape=(1, num_heads, out_feats), dtype="float32"
),
trainable=True,
)
self.feat_drop = layers.Dropout(rate=feat_drop)
self.attn_drop = layers.Dropout(rate=attn_drop)
self.leaky_relu = layers.LeakyReLU(alpha=negative_slope)
if residual:
if in_feats != out_feats:
self.res_fc = layers.Dense(
num_heads * out_feats, use_bias=False, kernel_initializer=xinit)
num_heads * out_feats,
use_bias=False,
kernel_initializer=xinit,
)
else:
self.res_fc = Identity()
else:
......@@ -220,39 +239,47 @@ class GATConv(layers.Layer):
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple):
src_prefix_shape = tuple(feat[0].shape[:-1])
dst_prefix_shape = tuple(feat[1].shape[:-1])
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
if not hasattr(self, 'fc_src'):
if not hasattr(self, "fc_src"):
self.fc_src, self.fc_dst = self.fc, self.fc
feat_src = tf.reshape(
self.fc_src(h_src),
src_prefix_shape + (self._num_heads, self._out_feats))
src_prefix_shape + (self._num_heads, self._out_feats),
)
feat_dst = tf.reshape(
self.fc_dst(h_dst),
dst_prefix_shape + (self._num_heads, self._out_feats))
dst_prefix_shape + (self._num_heads, self._out_feats),
)
else:
src_prefix_shape = dst_prefix_shape = tuple(feat.shape[:-1])
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = tf.reshape(
self.fc(h_src), src_prefix_shape + (self._num_heads, self._out_feats))
self.fc(h_src),
src_prefix_shape + (self._num_heads, self._out_feats),
)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
h_dst = h_dst[:graph.number_of_dst_nodes()]
dst_prefix_shape = (graph.number_of_dst_nodes(),) + dst_prefix_shape[1:]
feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_dst = h_dst[: graph.number_of_dst_nodes()]
dst_prefix_shape = (
graph.number_of_dst_nodes(),
) + dst_prefix_shape[1:]
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
......@@ -265,27 +292,27 @@ class GATConv(layers.Layer):
# which further speeds up computation and saves memory footprint.
el = tf.reduce_sum(feat_src * self.attn_l, axis=-1, keepdims=True)
er = tf.reduce_sum(feat_dst * self.attn_r, axis=-1, keepdims=True)
graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({"er": er})
# compute edge attention; el and er are a_l Wh_i and a_r Wh_j, respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.dstdata['ft']
graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
# residual
if self.res_fc is not None:
resval = tf.reshape(self.res_fc(
h_dst), dst_prefix_shape + (-1, self._out_feats))
resval = tf.reshape(
self.res_fc(h_dst), dst_prefix_shape + (-1, self._out_feats)
)
rst = rst + resval
# activation
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
return rst, graph.edata["a"]
else:
return rst
"""Tensorflow modules for graph convolutions(GCN)."""
import numpy as np
# pylint: disable= no-member, arguments-differ, invalid-name
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from .... import function as fn
from ....base import DGLError
......@@ -134,18 +135,23 @@ class GraphConv(layers.Layer):
[ 2.1405895, -0.2574358],
[ 1.3607183, -0.1636453]], dtype=float32)>
"""
def __init__(self,
in_feats,
out_feats,
norm='both',
weight=True,
bias=True,
activation=None,
allow_zero_in_degree=False):
def __init__(
self,
in_feats,
out_feats,
norm="both",
weight=True,
bias=True,
activation=None,
allow_zero_in_degree=False,
):
super(GraphConv, self).__init__()
if norm not in ('none', 'both', 'right', 'left'):
raise DGLError('Invalid norm value. Must be either "none", "both", "right" or "left".'
' But got "{}".'.format(norm))
if norm not in ("none", "both", "right", "left"):
raise DGLError(
'Invalid norm value. Must be either "none", "both", "right" or "left".'
' But got "{}".'.format(norm)
)
self._in_feats = in_feats
self._out_feats = out_feats
self._norm = norm
......@@ -153,15 +159,21 @@ class GraphConv(layers.Layer):
if weight:
xinit = tf.keras.initializers.glorot_uniform()
self.weight = tf.Variable(initial_value=xinit(
shape=(in_feats, out_feats), dtype='float32'), trainable=True)
self.weight = tf.Variable(
initial_value=xinit(
shape=(in_feats, out_feats), dtype="float32"
),
trainable=True,
)
else:
self.weight = None
if bias:
zeroinit = tf.keras.initializers.zeros()
self.bias = tf.Variable(initial_value=zeroinit(
shape=(out_feats), dtype='float32'), trainable=True)
self.bias = tf.Variable(
initial_value=zeroinit(shape=(out_feats), dtype="float32"),
trainable=True,
)
else:
self.bias = None
......@@ -216,23 +228,27 @@ class GraphConv(layers.Layer):
"""
with graph.local_scope():
if not self._allow_zero_in_degree:
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if tf.math.count_nonzero(graph.in_degrees() == 0) > 0:
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
feat_src, feat_dst = expand_as_pair(feat, graph)
if self._norm in ['both', 'left']:
degs = tf.clip_by_value(tf.cast(graph.out_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf)
if self._norm == 'both':
if self._norm in ["both", "left"]:
degs = tf.clip_by_value(
tf.cast(graph.out_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf,
)
if self._norm == "both":
norm = tf.pow(degs, -0.5)
else:
norm = 1.0 / degs
......@@ -242,9 +258,11 @@ class GraphConv(layers.Layer):
if weight is not None:
if self.weight is not None:
raise DGLError('External weight is provided while at the same time the'
' module has defined its own weight parameter. Please'
' create the module with flag weight=False.')
raise DGLError(
"External weight is provided while at the same time the"
" module has defined its own weight parameter. Please"
" create the module with flag weight=False."
)
else:
weight = self.weight
......@@ -252,24 +270,28 @@ class GraphConv(layers.Layer):
# mult W first to reduce the feature size for aggregation.
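# Sum aggregation commutes with the linear projection, so multiplying by W
# before aggregation gives the same result as multiplying after; the cheaper
# order is the one that keeps the smaller feature dimension during message
# passing.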
if weight is not None:
feat_src = tf.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
graph.srcdata["h"] = feat_src
graph.update_all(
fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
)
rst = graph.dstdata["h"]
else:
# aggregate first then mult W
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u(u='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
graph.srcdata["h"] = feat_src
graph.update_all(
fn.copy_u(u="h", out="m"), fn.sum(msg="m", out="h")
)
rst = graph.dstdata["h"]
if weight is not None:
rst = tf.matmul(rst, weight)
if self._norm in ['both', 'right']:
degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf)
if self._norm == 'both':
if self._norm in ["both", "right"]:
degs = tf.clip_by_value(
tf.cast(graph.in_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf,
)
if self._norm == "both":
norm = tf.pow(degs, -0.5)
else:
norm = 1.0 / degs
......@@ -289,8 +311,8 @@ class GraphConv(layers.Layer):
"""Set the extra representation of the module,
which will come into effect when printing the model.
"""
summary = 'in={_in_feats}, out={_out_feats}'
summary += ', normalization={_norm}'
if '_activation' in self.__dict__:
summary += ', activation={_activation}'
summary = "in={_in_feats}, out={_out_feats}"
summary += ", normalization={_norm}"
if "_activation" in self.__dict__:
summary += ", activation={_activation}"
return summary.format(**self.__dict__)
......@@ -5,7 +5,7 @@ from tensorflow.keras import layers
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape
from ....utils import check_eq_shape, expand_as_pair
class SAGEConv(layers.Layer):
......@@ -88,20 +88,25 @@ class SAGEConv(layers.Layer):
[ 0.3221837 , -0.29876417],
[-0.63356155, 0.09390211]], dtype=float32)>
"""
def __init__(self,
in_feats,
out_feats,
aggregator_type,
feat_drop=0.,
bias=True,
norm=None,
activation=None):
def __init__(
self,
in_feats,
out_feats,
aggregator_type,
feat_drop=0.0,
bias=True,
norm=None,
activation=None,
):
super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
valid_aggre_types = {"mean", "gcn", "pool", "lstm"}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
"Invalid aggregator_type. Must be one of {}. "
"But got {!r} instead.".format(
valid_aggre_types, aggregator_type
)
)
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
......@@ -111,11 +116,11 @@ class SAGEConv(layers.Layer):
self.feat_drop = layers.Dropout(feat_drop)
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
if aggregator_type == "pool":
self.fc_pool = layers.Dense(self._in_src_feats)
if aggregator_type == 'lstm':
if aggregator_type == "lstm":
self.lstm = layers.LSTM(units=self._in_src_feats)
if aggregator_type != 'gcn':
if aggregator_type != "gcn":
self.fc_self = layers.Dense(out_feats, use_bias=bias)
self.fc_neigh = layers.Dense(out_feats, use_bias=bias)
......@@ -124,9 +129,9 @@ class SAGEConv(layers.Layer):
NOTE(zihao): lstm reducer with default schedule (degree bucketing)
is slow; we could accelerate this with degree padding in the future.
"""
m = nodes.mailbox['m'] # (B, L, D)
m = nodes.mailbox["m"] # (B, L, D)
rst = self.lstm(m)
return {'neigh': rst}
return {"neigh": rst}
def call(self, graph, feat):
r"""Compute GraphSAGE layer.
......@@ -155,41 +160,47 @@ class SAGEConv(layers.Layer):
else:
feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
feat_dst = feat_src[: graph.number_of_dst_nodes()]
h_self = feat_dst
# Handle the case of graphs without edges
if graph.number_of_edges() == 0:
graph.dstdata['neigh'] = tf.cast(tf.zeros(
(graph.number_of_dst_nodes(), self._in_src_feats)), tf.float32)
if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
graph.dstdata["neigh"] = tf.cast(
tf.zeros((graph.number_of_dst_nodes(), self._in_src_feats)),
tf.float32,
)
if self._aggre_type == "mean":
graph.srcdata["h"] = feat_src
graph.update_all(fn.copy_u("h", "m"), fn.mean("m", "neigh"))
h_neigh = graph.dstdata["neigh"]
elif self._aggre_type == "gcn":
check_eq_shape(feat)
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
graph.srcdata["h"] = feat_src
graph.dstdata["h"] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "neigh"))
# divide in_degrees
degs = tf.cast(graph.in_degrees(), tf.float32)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']
) / (tf.expand_dims(degs, -1) + 1)
elif self._aggre_type == 'pool':
graph.srcdata['h'] = tf.nn.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm':
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_u('h', 'm'), self._lstm_reducer)
h_neigh = graph.dstdata['neigh']
h_neigh = (graph.dstdata["neigh"] + graph.dstdata["h"]) / (
tf.expand_dims(degs, -1) + 1
)
elif self._aggre_type == "pool":
graph.srcdata["h"] = tf.nn.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_u("h", "m"), fn.max("m", "neigh"))
h_neigh = graph.dstdata["neigh"]
elif self._aggre_type == "lstm":
graph.srcdata["h"] = feat_src
graph.update_all(fn.copy_u("h", "m"), self._lstm_reducer)
h_neigh = graph.dstdata["neigh"]
else:
raise KeyError(
'Aggregator type {} not recognized.'.format(self._aggre_type))
"Aggregator type {} not recognized.".format(
self._aggre_type
)
)
# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
if self._aggre_type == "gcn":
rst = self.fc_neigh(h_neigh)
else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
......