Unverified commit f25bc176, authored by Minjie Wang, committed by GitHub

[Hetero] Improve speed of several Hetero APIs (#1486)

* add clone function to frame

* add utest

* replace all local_var with local_scope

* fix utest

* avoid creating canonical types in __getitem__

* lint

* try another utest approach for mx

* utest
parent 3c4506e9
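
The change applied throughout the diff below is mechanical: every `graph = graph.local_var()` call is replaced by a `with graph.local_scope():` block, so temporary node and edge features written inside a layer are discarded when the block exits instead of requiring a cloned graph view up front. A minimal sketch of the pattern on a toy module (the `ToySum` class and its feature names are illustrative, not part of this commit):

    import dgl.function as fn
    import torch.nn as nn

    class ToySum(nn.Module):
        """Toy layer that sums neighbor features; illustrates the local_scope() idiom."""
        def forward(self, graph, feat):
            # Features written inside the scope ('h', 'h_sum') are dropped on exit,
            # so the caller's graph is left untouched.
            with graph.local_scope():
                graph.ndata['h'] = feat
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
                return graph.ndata['h_sum']
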
@@ -128,13 +128,13 @@ class Sequential(gluon.nn.Sequential):
     >>>     def __init__(self, **kwargs):
     >>>         super().__init__(**kwargs)
     >>>     def forward(self, graph, n_feat, e_feat):
-    >>>         graph = graph.local_var()
-    >>>         graph.ndata['h'] = n_feat
-    >>>         graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
-    >>>         n_feat += graph.ndata['h']
-    >>>         graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
-    >>>         e_feat += graph.edata['e']
-    >>>         return n_feat, e_feat
+    >>>         with graph.local_scope():
+    >>>             graph.ndata['h'] = n_feat
+    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
+    >>>             n_feat += graph.ndata['h']
+    >>>             graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
+    >>>             e_feat += graph.edata['e']
+    >>>             return n_feat, e_feat
     >>>
     >>> g = dgl.DGLGraph()
     >>> g.add_nodes(3)
@@ -175,11 +175,11 @@ class Sequential(gluon.nn.Sequential):
     >>>     def __init__(self, **kwargs):
     >>>         super().__init__(**kwargs)
     >>>     def forward(self, graph, n_feat):
-    >>>         graph = graph.local_var()
-    >>>         graph.ndata['h'] = n_feat
-    >>>         graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
-    >>>         n_feat += graph.ndata['h']
-    >>>         return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)
+    >>>         with graph.local_scope():
+    >>>             graph.ndata['h'] = n_feat
+    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
+    >>>             n_feat += graph.ndata['h']
+    >>>             return n_feat.reshape(graph.number_of_nodes() // 2, 2, -1).sum(1)
     >>>
     >>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
     >>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
@@ -58,17 +58,16 @@ class AGNNConv(nn.Module):
             The output feature of shape :math:`(N, *)` where :math:`*`
             should be the same as input shape.
         """
-        graph = graph.local_var()
-
-        feat_src, feat_dst = expand_as_pair(feat)
-        graph.srcdata['h'] = feat_src
-        graph.srcdata['norm_h'] = F.normalize(feat_src, p=2, dim=-1)
-        if isinstance(feat, tuple):
-            graph.dstdata['norm_h'] = F.normalize(feat_dst, p=2, dim=-1)
-        # compute cosine distance
-        graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
-        cos = graph.edata.pop('cos')
-        e = self.beta * cos
-        graph.edata['p'] = edge_softmax(graph, e)
-        graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
-        return graph.dstdata.pop('h')
+        with graph.local_scope():
+            feat_src, feat_dst = expand_as_pair(feat)
+            graph.srcdata['h'] = feat_src
+            graph.srcdata['norm_h'] = F.normalize(feat_src, p=2, dim=-1)
+            if isinstance(feat, tuple):
+                graph.dstdata['norm_h'] = F.normalize(feat_dst, p=2, dim=-1)
+            # compute cosine distance
+            graph.apply_edges(fn.u_dot_v('norm_h', 'norm_h', 'cos'))
+            cos = graph.edata.pop('cos')
+            e = self.beta * cos
+            graph.edata['p'] = edge_softmax(graph, e)
+            graph.update_all(fn.u_mul_e('h', 'p', 'm'), fn.sum('m', 'h'))
+            return graph.dstdata.pop('h')
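
For reference, the update the AGNNConv hunk above implements, written out in the usual AGNN notation (symbols are mine, not spelled out in the diff): the attention on edge (j, i) comes from the cosine similarity of the endpoint features scaled by the learned scalar beta, followed by a weighted sum over neighbors.

    P_{ij} = \mathrm{softmax}_{j \in \mathcal{N}(i)}\bigl(\beta \cos(h_i, h_j)\bigr),
    \qquad h_i' = \sum_{j \in \mathcal{N}(i)} P_{ij}\, h_j
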
@@ -53,21 +53,21 @@ class APPNPConv(nn.Module):
             The output feature of shape :math:`(N, *)` where :math:`*`
             should be the same as input shape.
         """
-        graph = graph.local_var()
-        norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
-        shp = norm.shape + (1,) * (feat.dim() - 1)
-        norm = th.reshape(norm, shp).to(feat.device)
-        feat_0 = feat
-        for _ in range(self._k):
-            # normalization by src node
-            feat = feat * norm
-            graph.ndata['h'] = feat
-            graph.edata['w'] = self.edge_drop(
-                th.ones(graph.number_of_edges(), 1).to(feat.device))
-            graph.update_all(fn.u_mul_e('h', 'w', 'm'),
-                             fn.sum('m', 'h'))
-            feat = graph.ndata.pop('h')
-            # normalization by dst node
-            feat = feat * norm
-            feat = (1 - self._alpha) * feat + self._alpha * feat_0
-        return feat
+        with graph.local_scope():
+            norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
+            shp = norm.shape + (1,) * (feat.dim() - 1)
+            norm = th.reshape(norm, shp).to(feat.device)
+            feat_0 = feat
+            for _ in range(self._k):
+                # normalization by src node
+                feat = feat * norm
+                graph.ndata['h'] = feat
+                graph.edata['w'] = self.edge_drop(
+                    th.ones(graph.number_of_edges(), 1).to(feat.device))
+                graph.update_all(fn.u_mul_e('h', 'w', 'm'),
+                                 fn.sum('m', 'h'))
+                feat = graph.ndata.pop('h')
+                # normalization by dst node
+                feat = feat * norm
+                feat = (1 - self._alpha) * feat + self._alpha * feat_0
+            return feat
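
The loop in the APPNPConv hunk is the standard APPNP power iteration; in conventional notation (mine, with the normalized, edge-dropped adjacency that the two `feat * norm` scalings and the `u_mul_e`/`sum` message passing realize written as a single operator):

    H^{(0)} = X, \qquad H^{(k+1)} = (1 - \alpha)\, \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} H^{(k)} + \alpha H^{(0)}
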
@@ -217,12 +217,12 @@ class AtomicConv(nn.Module):
             Updated node representations. V for the number of nodes, K for the
             number of radial filters, and T for the number of types of atomic numbers.
         """
-        radial_pooled_values = self.radial_pooling(distances)  # (K, E, 1)
-        graph = graph.local_var()
-        if self.features_to_use is not None:
-            feat = (feat == self.features_to_use).float()  # (V, T)
-        graph.ndata['hv'] = feat
-        graph.edata['he'] = radial_pooled_values.transpose(1, 0).squeeze(-1)  # (E, K)
-        graph.update_all(msg_func, reduce_func)
-        return graph.ndata['hv_new'].view(graph.number_of_nodes(), -1)  # (V, K * T)
+        with graph.local_scope():
+            radial_pooled_values = self.radial_pooling(distances)  # (K, E, 1)
+            if self.features_to_use is not None:
+                feat = (feat == self.features_to_use).float()  # (V, T)
+            graph.ndata['hv'] = feat
+            graph.edata['he'] = radial_pooled_values.transpose(1, 0).squeeze(-1)  # (E, K)
+            graph.update_all(msg_func, reduce_func)
+            return graph.ndata['hv_new'].view(graph.number_of_nodes(), -1)  # (V, K * T)
@@ -90,8 +90,8 @@ class CFConv(nn.Module):
         float32 tensor of shape (V, out_feats)
             Updated node representations.
         """
-        g = g.local_var()
-        g.ndata['hv'] = self.project_node(node_feats)
-        g.edata['he'] = self.project_edge(edge_feats)
-        g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h'))
-        return self.project_out(g.ndata['h'])
+        with g.local_scope():
+            g.ndata['hv'] = self.project_node(node_feats)
+            g.edata['he'] = self.project_edge(edge_feats)
+            g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h'))
+            return self.project_out(g.ndata['h'])
@@ -118,44 +118,44 @@ class GATConv(nn.Module):
             The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
         """
-        graph = graph.local_var()
-        if isinstance(feat, tuple):
-            h_src = self.feat_drop(feat[0])
-            h_dst = self.feat_drop(feat[1])
-            feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
-            feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
-        else:
-            h_src = h_dst = self.feat_drop(feat)
-            feat_src = feat_dst = self.fc(h_src).view(
-                -1, self._num_heads, self._out_feats)
-        # NOTE: GAT paper uses "first concatenation then linear projection"
-        # to compute attention scores, while ours is "first projection then
-        # addition", the two approaches are mathematically equivalent:
-        # We decompose the weight vector a mentioned in the paper into
-        # [a_l || a_r], then
-        # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
-        # Our implementation is much efficient because we do not need to
-        # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
-        # addition could be optimized with DGL's built-in function u_add_v,
-        # which further speeds up computation and saves memory footprint.
-        el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
-        er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
-        graph.srcdata.update({'ft': feat_src, 'el': el})
-        graph.dstdata.update({'er': er})
-        # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
-        graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
-        e = self.leaky_relu(graph.edata.pop('e'))
-        # compute softmax
-        graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
-        # message passing
-        graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
-                         fn.sum('m', 'ft'))
-        rst = graph.dstdata['ft']
-        # residual
-        if self.res_fc is not None:
-            resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
-            rst = rst + resval
-        # activation
-        if self.activation:
-            rst = self.activation(rst)
-        return rst
+        with graph.local_scope():
+            if isinstance(feat, tuple):
+                h_src = self.feat_drop(feat[0])
+                h_dst = self.feat_drop(feat[1])
+                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
+                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
+            else:
+                h_src = h_dst = self.feat_drop(feat)
+                feat_src = feat_dst = self.fc(h_src).view(
+                    -1, self._num_heads, self._out_feats)
+            # NOTE: GAT paper uses "first concatenation then linear projection"
+            # to compute attention scores, while ours is "first projection then
+            # addition", the two approaches are mathematically equivalent:
+            # We decompose the weight vector a mentioned in the paper into
+            # [a_l || a_r], then
+            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
+            # Our implementation is much efficient because we do not need to
+            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
+            # addition could be optimized with DGL's built-in function u_add_v,
+            # which further speeds up computation and saves memory footprint.
+            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
+            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
+            graph.srcdata.update({'ft': feat_src, 'el': el})
+            graph.dstdata.update({'er': er})
+            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
+            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
+            e = self.leaky_relu(graph.edata.pop('e'))
+            # compute softmax
+            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
+            # message passing
+            graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
+                             fn.sum('m', 'ft'))
+            rst = graph.dstdata['ft']
+            # residual
+            if self.res_fc is not None:
+                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
+                rst = rst + resval
+            # activation
+            if self.activation:
+                rst = self.activation(rst)
+            return rst
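
The NOTE comment in the GATConv hunk compresses a one-line derivation; written out, with the attention vector split as a = [a_l || a_r] as the comment says:

    a^{\top} [W h_i \,\|\, W h_j] = a_l^{\top} W h_i + a_r^{\top} W h_j

so `el` and `er` can be computed once per node and combined on edges with `u_add_v`, instead of materializing the concatenated vector on every edge.
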
@@ -77,22 +77,22 @@ class GatedGraphConv(nn.Module):
             The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
             is the output feature size.
         """
-        assert graph.is_homograph(), \
-            "not a homograph; convert it with to_homo and pass in the edge type as argument"
-        graph = graph.local_var()
-        zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1]))
-        feat = th.cat([feat, zero_pad], -1)
-        for _ in range(self._n_steps):
-            graph.ndata['h'] = feat
-            for i in range(self._n_etypes):
-                eids = (etypes == i).nonzero().view(-1)
-                if len(eids) > 0:
-                    graph.apply_edges(
-                        lambda edges: {'W_e*h': self.linears[i](edges.src['h'])},
-                        eids
-                    )
-            graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
-            a = graph.ndata.pop('a')  # (N, D)
-            feat = self.gru(a, feat)
-        return feat
+        with graph.local_scope():
+            assert graph.is_homograph(), \
+                "not a homograph; convert it with to_homo and pass in the edge type as argument"
+            zero_pad = feat.new_zeros((feat.shape[0], self._out_feats - feat.shape[1]))
+            feat = th.cat([feat, zero_pad], -1)
+            for _ in range(self._n_steps):
+                graph.ndata['h'] = feat
+                for i in range(self._n_etypes):
+                    eids = (etypes == i).nonzero().view(-1)
+                    if len(eids) > 0:
+                        graph.apply_edges(
+                            lambda edges: {'W_e*h': self.linears[i](edges.src['h'])},
+                            eids
+                        )
+                graph.update_all(fn.copy_e('W_e*h', 'm'), fn.sum('m', 'a'))
+                a = graph.ndata.pop('a')  # (N, D)
+                feat = self.gru(a, feat)
+            return feat
@@ -72,11 +72,11 @@ class GINConv(nn.Module):
             If ``apply_func`` is None, :math:`D_{out}` should be the same
             as input dimensionality.
         """
-        graph = graph.local_var()
-        feat_src, feat_dst = expand_as_pair(feat)
-        graph.srcdata['h'] = feat_src
-        graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
-        rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
-        if self.apply_func is not None:
-            rst = self.apply_func(rst)
-        return rst
+        with graph.local_scope():
+            feat_src, feat_dst = expand_as_pair(feat)
+            graph.srcdata['h'] = feat_src
+            graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
+            rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
+            if self.apply_func is not None:
+                rst = self.apply_func(rst)
+            return rst
@@ -125,57 +125,56 @@ class GraphConv(nn.Module):
         torch.Tensor
             The output feature
         """
-        graph = graph.local_var()
-
-        if self._norm == 'both':
-            degs = graph.out_degrees().to(feat.device).float().clamp(min=1)
-            norm = th.pow(degs, -0.5)
-            shp = norm.shape + (1,) * (feat.dim() - 1)
-            norm = th.reshape(norm, shp)
-            feat = feat * norm
-        if weight is not None:
-            if self.weight is not None:
-                raise DGLError('External weight is provided while at the same time the'
-                               ' module has defined its own weight parameter. Please'
-                               ' create the module with flag weight=False.')
-        else:
-            weight = self.weight
-        if self._in_feats > self._out_feats:
-            # mult W first to reduce the feature size for aggregation.
-            if weight is not None:
-                feat = th.matmul(feat, weight)
-            graph.srcdata['h'] = feat
-            graph.update_all(fn.copy_src(src='h', out='m'),
-                             fn.sum(msg='m', out='h'))
-            rst = graph.dstdata['h']
-        else:
-            # aggregate first then mult W
-            graph.srcdata['h'] = feat
-            graph.update_all(fn.copy_src(src='h', out='m'),
-                             fn.sum(msg='m', out='h'))
-            rst = graph.dstdata['h']
-            if weight is not None:
-                rst = th.matmul(rst, weight)
-        if self._norm != 'none':
-            degs = graph.in_degrees().to(feat.device).float().clamp(min=1)
-            if self._norm == 'both':
-                norm = th.pow(degs, -0.5)
-            else:
-                norm = 1.0 / degs
-            shp = norm.shape + (1,) * (feat.dim() - 1)
-            norm = th.reshape(norm, shp)
-            rst = rst * norm
-        if self.bias is not None:
-            rst = rst + self.bias
-        if self._activation is not None:
-            rst = self._activation(rst)
-        return rst
+        with graph.local_scope():
+            if self._norm == 'both':
+                degs = graph.out_degrees().to(feat.device).float().clamp(min=1)
+                norm = th.pow(degs, -0.5)
+                shp = norm.shape + (1,) * (feat.dim() - 1)
+                norm = th.reshape(norm, shp)
+                feat = feat * norm
+            if weight is not None:
+                if self.weight is not None:
+                    raise DGLError('External weight is provided while at the same time the'
+                                   ' module has defined its own weight parameter. Please'
+                                   ' create the module with flag weight=False.')
+            else:
+                weight = self.weight
+            if self._in_feats > self._out_feats:
+                # mult W first to reduce the feature size for aggregation.
+                if weight is not None:
+                    feat = th.matmul(feat, weight)
+                graph.srcdata['h'] = feat
+                graph.update_all(fn.copy_src(src='h', out='m'),
+                                 fn.sum(msg='m', out='h'))
+                rst = graph.dstdata['h']
+            else:
+                # aggregate first then mult W
+                graph.srcdata['h'] = feat
+                graph.update_all(fn.copy_src(src='h', out='m'),
+                                 fn.sum(msg='m', out='h'))
+                rst = graph.dstdata['h']
+                if weight is not None:
+                    rst = th.matmul(rst, weight)
+            if self._norm != 'none':
+                degs = graph.in_degrees().to(feat.device).float().clamp(min=1)
+                if self._norm == 'both':
+                    norm = th.pow(degs, -0.5)
+                else:
+                    norm = 1.0 / degs
+                shp = norm.shape + (1,) * (feat.dim() - 1)
+                norm = th.reshape(norm, shp)
+                rst = rst * norm
+            if self.bias is not None:
+                rst = rst + self.bias
+            if self._activation is not None:
+                rst = self._activation(rst)
+            return rst

     def extra_repr(self):
         """Set the extra representation of the module,
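
For orientation, with `norm='both'` the GraphConv hunk above computes (notation mine; degrees are clamped to a minimum of 1 rather than adding self-loops):

    \mathrm{rst} = D_{\mathrm{in}}^{-1/2}\, A\, D_{\mathrm{out}}^{-1/2}\, X\, W

and the `_in_feats > _out_feats` branch only decides whether the multiplication by W happens before or after aggregation, which changes the cost but not the result.
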
@@ -114,48 +114,47 @@ class SAGEConv(nn.Module):
             The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
             is size of output feature.
         """
-        graph = graph.local_var()
-
-        if isinstance(feat, tuple):
-            feat_src = self.feat_drop(feat[0])
-            feat_dst = self.feat_drop(feat[1])
-        else:
-            feat_src = feat_dst = self.feat_drop(feat)
-        h_self = feat_dst
-        if self._aggre_type == 'mean':
-            graph.srcdata['h'] = feat_src
-            graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
-            h_neigh = graph.dstdata['neigh']
-        elif self._aggre_type == 'gcn':
-            check_eq_shape(feat)
-            graph.srcdata['h'] = feat_src
-            graph.dstdata['h'] = feat_dst  # same as above if homogeneous
-            graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
-            # divide in_degrees
-            degs = graph.in_degrees().to(feat_dst)
-            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
-        elif self._aggre_type == 'pool':
-            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
-            graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
-            h_neigh = graph.dstdata['neigh']
-        elif self._aggre_type == 'lstm':
-            graph.srcdata['h'] = feat_src
-            graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
-            h_neigh = graph.dstdata['neigh']
-        else:
-            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
-        # GraphSAGE GCN does not require fc_self.
-        if self._aggre_type == 'gcn':
-            rst = self.fc_neigh(h_neigh)
-        else:
-            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
-        # activation
-        if self.activation is not None:
-            rst = self.activation(rst)
-        # normalization
-        if self.norm is not None:
-            rst = self.norm(rst)
-        return rst
+        with graph.local_scope():
+            if isinstance(feat, tuple):
+                feat_src = self.feat_drop(feat[0])
+                feat_dst = self.feat_drop(feat[1])
+            else:
+                feat_src = feat_dst = self.feat_drop(feat)
+            h_self = feat_dst
+            if self._aggre_type == 'mean':
+                graph.srcdata['h'] = feat_src
+                graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
+                h_neigh = graph.dstdata['neigh']
+            elif self._aggre_type == 'gcn':
+                check_eq_shape(feat)
+                graph.srcdata['h'] = feat_src
+                graph.dstdata['h'] = feat_dst  # same as above if homogeneous
+                graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
+                # divide in_degrees
+                degs = graph.in_degrees().to(feat_dst)
+                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
+            elif self._aggre_type == 'pool':
+                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
+                graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
+                h_neigh = graph.dstdata['neigh']
+            elif self._aggre_type == 'lstm':
+                graph.srcdata['h'] = feat_src
+                graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
+                h_neigh = graph.dstdata['neigh']
+            else:
+                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
+            # GraphSAGE GCN does not require fc_self.
+            if self._aggre_type == 'gcn':
+                rst = self.fc_neigh(h_neigh)
+            else:
+                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
+            # activation
+            if self.activation is not None:
+                rst = self.activation(rst)
+            # normalization
+            if self.norm is not None:
+                rst = self.norm(rst)
+            return rst
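
As a reading aid for the 'gcn' aggregator in the SAGEConv hunk (notation mine): the destination feature is folded into the mean over in-neighbors before the linear projection,

    h_{\mathcal{N}(i)} = \frac{h_i + \sum_{j \in \mathcal{N}(i)} h_j}{\deg_{\mathrm{in}}(i) + 1},
    \qquad \mathrm{rst}_i = W_{\mathrm{neigh}}\, h_{\mathcal{N}(i)}

which is why that branch skips `fc_self`.
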
@@ -77,27 +77,27 @@ class SGConv(nn.Module):
         If ``cache`` is set to True, ``feat`` and ``graph`` should not change during
         training, or you will get wrong results.
         """
-        graph = graph.local_var()
-        if self._cached_h is not None:
-            feat = self._cached_h
-        else:
-            # compute normalization
-            degs = graph.in_degrees().float().clamp(min=1)
-            norm = th.pow(degs, -0.5)
-            norm = norm.to(feat.device).unsqueeze(1)
-            # compute (D^-1 A^k D)^k X
-            for _ in range(self._k):
-                feat = feat * norm
-                graph.ndata['h'] = feat
-                graph.update_all(fn.copy_u('h', 'm'),
-                                 fn.sum('m', 'h'))
-                feat = graph.ndata.pop('h')
-                feat = feat * norm
-            if self.norm is not None:
-                feat = self.norm(feat)
-            # cache feature
-            if self._cached:
-                self._cached_h = feat
-        return self.fc(feat)
+        with graph.local_scope():
+            if self._cached_h is not None:
+                feat = self._cached_h
+            else:
+                # compute normalization
+                degs = graph.in_degrees().float().clamp(min=1)
+                norm = th.pow(degs, -0.5)
+                norm = norm.to(feat.device).unsqueeze(1)
+                # compute (D^-1 A^k D)^k X
+                for _ in range(self._k):
+                    feat = feat * norm
+                    graph.ndata['h'] = feat
+                    graph.update_all(fn.copy_u('h', 'm'),
+                                     fn.sum('m', 'h'))
+                    feat = graph.ndata.pop('h')
+                    feat = feat * norm
+                if self.norm is not None:
+                    feat = self.norm(feat)
+                # cache feature
+                if self._cached:
+                    self._cached_h = feat
+            return self.fc(feat)
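
The K-step loop in the SGConv hunk computes the pre-propagated features that are then optionally cached and fed to the single linear layer; in the usual SGC notation (mine):

    \hat{X} = \bigl(D^{-1/2} A D^{-1/2}\bigr)^{K} X, \qquad Y = \hat{X}\, W
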
@@ -73,29 +73,29 @@ class TAGConv(nn.Module):
             The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
             is size of output feature.
         """
-        assert graph.is_homograph(), 'Graph is not homogeneous'
-        graph = graph.local_var()
-        norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
-        shp = norm.shape + (1,) * (feat.dim() - 1)
-        norm = th.reshape(norm, shp).to(feat.device)
-        #D-1/2 A D -1/2 X
-        fstack = [feat]
-        for _ in range(self._k):
-            rst = fstack[-1] * norm
-            graph.ndata['h'] = rst
-            graph.update_all(fn.copy_src(src='h', out='m'),
-                             fn.sum(msg='m', out='h'))
-            rst = graph.ndata['h']
-            rst = rst * norm
-            fstack.append(rst)
-        rst = self.lin(th.cat(fstack, dim=-1))
-        if self._activation is not None:
-            rst = self._activation(rst)
-        return rst
+        with graph.local_scope():
+            assert graph.is_homograph(), 'Graph is not homogeneous'
+            norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
+            shp = norm.shape + (1,) * (feat.dim() - 1)
+            norm = th.reshape(norm, shp).to(feat.device)
+            #D-1/2 A D -1/2 X
+            fstack = [feat]
+            for _ in range(self._k):
+                rst = fstack[-1] * norm
+                graph.ndata['h'] = rst
+                graph.update_all(fn.copy_src(src='h', out='m'),
+                                 fn.sum(msg='m', out='h'))
+                rst = graph.ndata['h']
+                rst = rst * norm
+                fstack.append(rst)
+            rst = self.lin(th.cat(fstack, dim=-1))
+            if self._activation is not None:
+                rst = self._activation(rst)
+            return rst
@@ -130,13 +130,13 @@ class Sequential(nn.Sequential):
     >>>     def __init__(self):
    >>>         super().__init__()
     >>>     def forward(self, graph, n_feat, e_feat):
-    >>>         graph = graph.local_var()
-    >>>         graph.ndata['h'] = n_feat
-    >>>         graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
-    >>>         n_feat += graph.ndata['h']
-    >>>         graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
-    >>>         e_feat += graph.edata['e']
-    >>>         return n_feat, e_feat
+    >>>         with graph.local_scope():
+    >>>             graph.ndata['h'] = n_feat
+    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
+    >>>             n_feat += graph.ndata['h']
+    >>>             graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
+    >>>             e_feat += graph.edata['e']
+    >>>             return n_feat, e_feat
     >>>
     >>> g = dgl.DGLGraph()
     >>> g.add_nodes(3)
@@ -169,11 +169,11 @@ class Sequential(nn.Sequential):
     >>>     def __init__(self):
     >>>         super().__init__()
     >>>     def forward(self, graph, n_feat):
-    >>>         graph = graph.local_var()
-    >>>         graph.ndata['h'] = n_feat
-    >>>         graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
-    >>>         n_feat += graph.ndata['h']
-    >>>         return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)
+    >>>         with graph.local_scope():
+    >>>             graph.ndata['h'] = n_feat
+    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
+    >>>             n_feat += graph.ndata['h']
+    >>>             return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)
     >>>
     >>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
     >>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
@@ -55,23 +55,23 @@ class APPNPConv(layers.Layer):
             The output feature of shape :math:`(N, *)` where :math:`*`
             should be the same as input shape.
         """
-        graph = graph.local_var()
-        degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32),
-                                clip_value_min=1, clip_value_max=np.inf)
-        norm = tf.pow(degs, -0.5)
-        shp = norm.shape + (1,) * (feat.ndim - 1)
-        norm = tf.reshape(norm, shp)
-        feat_0 = feat
-        for _ in range(self._k):
-            # normalization by src node
-            feat = feat * norm
-            graph.ndata['h'] = feat
-            graph.edata['w'] = self.edge_drop(
-                tf.ones(graph.number_of_edges(), 1))
-            graph.update_all(fn.u_mul_e('h', 'w', 'm'),
-                             fn.sum('m', 'h'))
-            feat = graph.ndata.pop('h')
-            # normalization by dst node
-            feat = feat * norm
-            feat = (1 - self._alpha) * feat + self._alpha * feat_0
-        return feat
+        with graph.local_scope():
+            degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32),
+                                    clip_value_min=1, clip_value_max=np.inf)
+            norm = tf.pow(degs, -0.5)
+            shp = norm.shape + (1,) * (feat.ndim - 1)
+            norm = tf.reshape(norm, shp)
+            feat_0 = feat
+            for _ in range(self._k):
+                # normalization by src node
+                feat = feat * norm
+                graph.ndata['h'] = feat
+                graph.edata['w'] = self.edge_drop(
+                    tf.ones(graph.number_of_edges(), 1))
+                graph.update_all(fn.u_mul_e('h', 'w', 'm'),
+                                 fn.sum('m', 'h'))
+                feat = graph.ndata.pop('h')
+                # normalization by dst node
+                feat = feat * norm
+                feat = (1 - self._alpha) * feat + self._alpha * feat_0
+            return feat
@@ -112,45 +112,45 @@ class GATConv(layers.Layer):
             The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
             is the number of heads, and :math:`D_{out}` is size of output feature.
         """
-        graph = graph.local_var()
-        if isinstance(feat, tuple):
-            h_src = self.feat_drop(feat[0])
-            h_dst = self.feat_drop(feat[1])
-            feat_src = tf.reshape(self.fc_src(h_src), (-1, self._num_heads, self._out_feats))
-            feat_dst = tf.reshape(self.fc_dst(h_dst), (-1, self._num_heads, self._out_feats))
-        else:
-            h_src = h_dst = self.feat_drop(feat)
-            feat_src = feat_dst = tf.reshape(
-                self.fc(h_src), (-1, self._num_heads, self._out_feats))
-        # NOTE: GAT paper uses "first concatenation then linear projection"
-        # to compute attention scores, while ours is "first projection then
-        # addition", the two approaches are mathematically equivalent:
-        # We decompose the weight vector a mentioned in the paper into
-        # [a_l || a_r], then
-        # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
-        # Our implementation is much efficient because we do not need to
-        # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
-        # addition could be optimized with DGL's built-in function u_add_v,
-        # which further speeds up computation and saves memory footprint.
-        el = tf.reduce_sum(feat_src * self.attn_l, axis=-1, keepdims=True)
-        er = tf.reduce_sum(feat_dst * self.attn_r, axis=-1, keepdims=True)
-        graph.srcdata.update({'ft': feat_src, 'el': el})
-        graph.dstdata.update({'er': er})
-        # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
-        graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
-        e = self.leaky_relu(graph.edata.pop('e'))
-        # compute softmax
-        graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
-        # message passing
-        graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
-                         fn.sum('m', 'ft'))
-        rst = graph.dstdata['ft']
-        # residual
-        if self.res_fc is not None:
-            resval = tf.reshape(self.res_fc(
-                h_dst), (h_dst.shape[0], -1, self._out_feats))
-            rst = rst + resval
-        # activation
-        if self.activation:
-            rst = self.activation(rst)
-        return rst
+        with graph.local_scope():
+            if isinstance(feat, tuple):
+                h_src = self.feat_drop(feat[0])
+                h_dst = self.feat_drop(feat[1])
+                feat_src = tf.reshape(self.fc_src(h_src), (-1, self._num_heads, self._out_feats))
+                feat_dst = tf.reshape(self.fc_dst(h_dst), (-1, self._num_heads, self._out_feats))
+            else:
+                h_src = h_dst = self.feat_drop(feat)
+                feat_src = feat_dst = tf.reshape(
+                    self.fc(h_src), (-1, self._num_heads, self._out_feats))
+            # NOTE: GAT paper uses "first concatenation then linear projection"
+            # to compute attention scores, while ours is "first projection then
+            # addition", the two approaches are mathematically equivalent:
+            # We decompose the weight vector a mentioned in the paper into
+            # [a_l || a_r], then
+            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
+            # Our implementation is much efficient because we do not need to
+            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
+            # addition could be optimized with DGL's built-in function u_add_v,
+            # which further speeds up computation and saves memory footprint.
+            el = tf.reduce_sum(feat_src * self.attn_l, axis=-1, keepdims=True)
+            er = tf.reduce_sum(feat_dst * self.attn_r, axis=-1, keepdims=True)
+            graph.srcdata.update({'ft': feat_src, 'el': el})
+            graph.dstdata.update({'er': er})
+            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
+            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
+            e = self.leaky_relu(graph.edata.pop('e'))
+            # compute softmax
+            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
+            # message passing
+            graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
+                             fn.sum('m', 'ft'))
+            rst = graph.dstdata['ft']
+            # residual
+            if self.res_fc is not None:
+                resval = tf.reshape(self.res_fc(
+                    h_dst), (h_dst.shape[0], -1, self._out_feats))
+                rst = rst + resval
+            # activation
+            if self.activation:
+                rst = self.activation(rst)
+            return rst
@@ -70,11 +70,11 @@ class GINConv(layers.Layer):
             If ``apply_func`` is None, :math:`D_{out}` should be the same
            as input dimensionality.
         """
-        graph = graph.local_var()
-        feat_src, feat_dst = expand_as_pair(feat)
-        graph.srcdata['h'] = feat_src
-        graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
-        rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
-        if self.apply_func is not None:
-            rst = self.apply_func(rst)
-        return rst
+        with graph.local_scope():
+            feat_src, feat_dst = expand_as_pair(feat)
+            graph.srcdata['h'] = feat_src
+            graph.update_all(fn.copy_u('h', 'm'), self._reducer('m', 'neigh'))
+            rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
+            if self.apply_func is not None:
+                rst = self.apply_func(rst)
+            return rst
@@ -122,61 +122,60 @@ class GraphConv(layers.Layer):
         tf.Tensor
             The output feature
         """
-        graph = graph.local_var()
-
-        if self._norm == 'both':
-            degs = tf.clip_by_value(tf.cast(graph.out_degrees(), tf.float32),
-                                    clip_value_min=1,
-                                    clip_value_max=np.inf)
-            norm = tf.pow(degs, -0.5)
-            shp = norm.shape + (1,) * (feat.ndim - 1)
-            norm = tf.reshape(norm, shp)
-            feat = feat * norm
-        if weight is not None:
-            if self.weight is not None:
-                raise DGLError('External weight is provided while at the same time the'
-                               ' module has defined its own weight parameter. Please'
-                               ' create the module with flag weight=False.')
-        else:
-            weight = self.weight
-        if self._in_feats > self._out_feats:
-            # mult W first to reduce the feature size for aggregation.
-            if weight is not None:
-                feat = tf.matmul(feat, weight)
-            graph.srcdata['h'] = feat
-            graph.update_all(fn.copy_src(src='h', out='m'),
-                             fn.sum(msg='m', out='h'))
-            rst = graph.dstdata['h']
-        else:
-            # aggregate first then mult W
-            graph.srcdata['h'] = feat
-            graph.update_all(fn.copy_src(src='h', out='m'),
-                             fn.sum(msg='m', out='h'))
-            rst = graph.dstdata['h']
-            if weight is not None:
-                rst = tf.matmul(rst, weight)
-        if self._norm != 'none':
-            degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32),
-                                    clip_value_min=1,
-                                    clip_value_max=np.inf)
-            if self._norm == 'both':
-                norm = tf.pow(degs, -0.5)
-            else:
-                norm = 1.0 / degs
-            shp = norm.shape + (1,) * (feat.ndim - 1)
-            norm = tf.reshape(norm, shp)
-            rst = rst * norm
-        if self.bias is not None:
-            rst = rst + self.bias
-        if self._activation is not None:
-            rst = self._activation(rst)
-        return rst
+        with graph.local_scope():
+            if self._norm == 'both':
+                degs = tf.clip_by_value(tf.cast(graph.out_degrees(), tf.float32),
+                                        clip_value_min=1,
+                                        clip_value_max=np.inf)
+                norm = tf.pow(degs, -0.5)
+                shp = norm.shape + (1,) * (feat.ndim - 1)
+                norm = tf.reshape(norm, shp)
+                feat = feat * norm
+            if weight is not None:
+                if self.weight is not None:
+                    raise DGLError('External weight is provided while at the same time the'
+                                   ' module has defined its own weight parameter. Please'
+                                   ' create the module with flag weight=False.')
+            else:
+                weight = self.weight
+            if self._in_feats > self._out_feats:
+                # mult W first to reduce the feature size for aggregation.
+                if weight is not None:
+                    feat = tf.matmul(feat, weight)
+                graph.srcdata['h'] = feat
+                graph.update_all(fn.copy_src(src='h', out='m'),
+                                 fn.sum(msg='m', out='h'))
+                rst = graph.dstdata['h']
+            else:
+                # aggregate first then mult W
+                graph.srcdata['h'] = feat
+                graph.update_all(fn.copy_src(src='h', out='m'),
+                                 fn.sum(msg='m', out='h'))
+                rst = graph.dstdata['h']
+                if weight is not None:
+                    rst = tf.matmul(rst, weight)
+            if self._norm != 'none':
+                degs = tf.clip_by_value(tf.cast(graph.in_degrees(), tf.float32),
+                                        clip_value_min=1,
+                                        clip_value_max=np.inf)
+                if self._norm == 'both':
+                    norm = tf.pow(degs, -0.5)
+                else:
+                    norm = 1.0 / degs
+                shp = norm.shape + (1,) * (feat.ndim - 1)
+                norm = tf.reshape(norm, shp)
+                rst = rst * norm
+            if self.bias is not None:
+                rst = rst + self.bias
+            if self._activation is not None:
+                rst = self._activation(rst)
+            return rst

     def extra_repr(self):
         """Set the extra representation of the module,
@@ -100,49 +100,48 @@ class SAGEConv(layers.Layer):
             The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
             is size of output feature.
         """
-        graph = graph.local_var()
-
-        if isinstance(feat, tuple):
-            feat_src = self.feat_drop(feat[0])
-            feat_dst = self.feat_drop(feat[1])
-        else:
-            feat_src = feat_dst = self.feat_drop(feat)
-        h_self = feat_dst
-        if self._aggre_type == 'mean':
-            graph.srcdata['h'] = feat_src
-            graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
-            h_neigh = graph.dstdata['neigh']
-        elif self._aggre_type == 'gcn':
-            check_eq_shape(feat)
-            graph.srcdata['h'] = feat_src
-            graph.dstdata['h'] = feat_dst  # same as above if homogeneous
-            graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
-            # divide in_degrees
-            degs = tf.cast(graph.in_degrees(), tf.float32)
-            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']
-                       ) / (tf.expand_dims(degs, -1) + 1)
-        elif self._aggre_type == 'pool':
-            graph.srcdata['h'] = tf.nn.relu(self.fc_pool(feat_src))
-            graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
-            h_neigh = graph.dstdata['neigh']
-        elif self._aggre_type == 'lstm':
-            graph.srcdata['h'] = feat_src
-            graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
-            h_neigh = graph.dstdata['neigh']
-        else:
-            raise KeyError(
-                'Aggregator type {} not recognized.'.format(self._aggre_type))
-        # GraphSAGE GCN does not require fc_self.
-        if self._aggre_type == 'gcn':
-            rst = self.fc_neigh(h_neigh)
-        else:
-            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
-        # activation
-        if self.activation is not None:
-            rst = self.activation(rst)
-        # normalization
-        if self.norm is not None:
-            rst = self.norm(rst)
-        return rst
+        with graph.local_scope():
+            if isinstance(feat, tuple):
+                feat_src = self.feat_drop(feat[0])
+                feat_dst = self.feat_drop(feat[1])
+            else:
+                feat_src = feat_dst = self.feat_drop(feat)
+            h_self = feat_dst
+            if self._aggre_type == 'mean':
+                graph.srcdata['h'] = feat_src
+                graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
+                h_neigh = graph.dstdata['neigh']
+            elif self._aggre_type == 'gcn':
+                check_eq_shape(feat)
+                graph.srcdata['h'] = feat_src
+                graph.dstdata['h'] = feat_dst  # same as above if homogeneous
+                graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
+                # divide in_degrees
+                degs = tf.cast(graph.in_degrees(), tf.float32)
+                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']
+                           ) / (tf.expand_dims(degs, -1) + 1)
+            elif self._aggre_type == 'pool':
+                graph.srcdata['h'] = tf.nn.relu(self.fc_pool(feat_src))
+                graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
+                h_neigh = graph.dstdata['neigh']
+            elif self._aggre_type == 'lstm':
+                graph.srcdata['h'] = feat_src
+                graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
+                h_neigh = graph.dstdata['neigh']
+            else:
+                raise KeyError(
+                    'Aggregator type {} not recognized.'.format(self._aggre_type))
+            # GraphSAGE GCN does not require fc_self.
+            if self._aggre_type == 'gcn':
+                rst = self.fc_neigh(h_neigh)
+            else:
+                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
+            # activation
+            if self.activation is not None:
+                rst = self.activation(rst)
+            # normalization
+            if self.norm is not None:
+                rst = self.norm(rst)
+            return rst
@@ -72,28 +72,28 @@ class SGConv(layers.Layer):
         If ``cache`` is set to True, ``feat`` and ``graph`` should not change during
         training, or you will get wrong results.
         """
-        graph = graph.local_var()
-        if self._cached_h is not None:
-            feat = self._cached_h
-        else:
-            # compute normalization
-            degs = tf.clip_by_value(tf.cast(
-                graph.in_degrees(), tf.float32), clip_value_min=1, clip_value_max=np.inf)
-            norm = tf.pow(degs, -0.5)
-            norm = tf.expand_dims(norm, 1)
-            # compute (D^-1 A^k D)^k X
-            for _ in range(self._k):
-                feat = feat * norm
-                graph.ndata['h'] = feat
-                graph.update_all(fn.copy_u('h', 'm'),
-                                 fn.sum('m', 'h'))
-                feat = graph.ndata.pop('h')
-                feat = feat * norm
-            if self.norm is not None:
-                feat = self.norm(feat)
-            # cache feature
-            if self._cached:
-                self._cached_h = feat
-        return self.fc(feat)
+        with graph.local_scope():
+            if self._cached_h is not None:
+                feat = self._cached_h
+            else:
+                # compute normalization
+                degs = tf.clip_by_value(tf.cast(
+                    graph.in_degrees(), tf.float32), clip_value_min=1, clip_value_max=np.inf)
+                norm = tf.pow(degs, -0.5)
+                norm = tf.expand_dims(norm, 1)
+                # compute (D^-1 A^k D)^k X
+                for _ in range(self._k):
+                    feat = feat * norm
+                    graph.ndata['h'] = feat
+                    graph.update_all(fn.copy_u('h', 'm'),
+                                     fn.sum('m', 'h'))
+                    feat = graph.ndata.pop('h')
+                    feat = feat * norm
+                if self.norm is not None:
+                    feat = self.norm(feat)
+                # cache feature
+                if self._cached:
+                    self._cached_h = feat
+            return self.fc(feat)
@@ -12,24 +12,24 @@ def edge_softmax_real(graph, score, eids=ALL):
     """Edge Softmax function"""
     if not is_all(eids):
         graph = graph.edge_subgraph(tf.cast(eids, tf.int64))
-    g = graph.local_var()
-    g.edata['s'] = score
-    g.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
-    g.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
-    g.edata['out'] = tf.math.exp(g.edata['out'])
-    g.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
-    g.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
-    out = g.edata['out']
+    with graph.local_scope():
+        graph.edata['s'] = score
+        graph.update_all(fn.copy_e('s', 'm'), fn.max('m', 'smax'))
+        graph.apply_edges(fn.e_sub_v('s', 'smax', 'out'))
+        graph.edata['out'] = tf.math.exp(graph.edata['out'])
+        graph.update_all(fn.copy_e('out', 'm'), fn.sum('m', 'out_sum'))
+        graph.apply_edges(fn.e_div_v('out', 'out_sum', 'out'))
+        out = graph.edata['out']

     def edge_softmax_backward(grad_out):
-        g = graph.local_var()
-        # clear backward cache explicitly
-        g.edata['out'] = out
-        g.edata['grad_s'] = out * grad_out
-        g.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
-        g.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
-        grad_score = g.edata['grad_s'] - g.edata['out']
+        with graph.local_scope():
+            # clear backward cache explicitly
+            graph.edata['out'] = out
+            graph.edata['grad_s'] = out * grad_out
+            graph.update_all(fn.copy_e('grad_s', 'm'), fn.sum('m', 'accum'))
+            graph.apply_edges(fn.e_mul_v('out', 'accum', 'out'))
+            grad_score = graph.edata['grad_s'] - graph.edata['out']
         return grad_score

     return out, edge_softmax_backward
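
What the rewritten `edge_softmax_real` computes, for edge scores z grouped by destination node (notation mine): the forward pass is a numerically stabilized softmax over the in-edges of each destination,

    \alpha_{e} = \frac{\exp\bigl(z_{e} - \max_{e' \to \mathrm{dst}(e)} z_{e'}\bigr)}
                      {\sum_{e' \to \mathrm{dst}(e)} \exp\bigl(z_{e'} - \max_{e'' \to \mathrm{dst}(e)} z_{e''}\bigr)}

and the backward closure returns the softmax Jacobian-vector product, with \bar{\alpha} the incoming gradient (`grad_out`):

    \frac{\partial L}{\partial z_{e}} = \alpha_{e}\Bigl(\bar{\alpha}_{e} - \sum_{e' \to \mathrm{dst}(e)} \alpha_{e'}\, \bar{\alpha}_{e'}\Bigr)
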