"...api/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "3f240fbb3734ab5f112a3d26d3856cf0a0e1a092"
Unverified Commit edf64463 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[performance] Optimize the association order of AXW in GraphSAGE. (#2747)



* upd

* lint

* upd

* upd

* compatibility

* upd

* upd
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 366cc7eb
...@@ -5,7 +5,7 @@ from torch import nn ...@@ -5,7 +5,7 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from .... import function as fn from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape from ....utils import expand_as_pair, check_eq_shape, dgl_warning
class SAGEConv(nn.Module): class SAGEConv(nn.Module):
...@@ -119,8 +119,12 @@ class SAGEConv(nn.Module): ...@@ -119,8 +119,12 @@ class SAGEConv(nn.Module):
if aggregator_type == 'lstm': if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True) self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type != 'gcn': if aggregator_type != 'gcn':
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias) self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=False)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias) self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
if bias:
self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
else:
self.register_buffer('bias', None)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
...@@ -144,6 +148,19 @@ class SAGEConv(nn.Module): ...@@ -144,6 +148,19 @@ class SAGEConv(nn.Module):
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain) nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain) nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
def _compatibility_check(self):
"""Address the backward compatibility issue brought by #2747"""
if not hasattr(self, 'bias'):
dgl_warning("You are loading a GraphSAGE model trained from a old version of DGL, "
"DGL automatically convert it to be compatible with latest version.")
bias = self.fc_neigh.bias
self.fc_neigh.bias = None
if hasattr(self, 'fc_self'):
if bias is not None:
bias = bias + self.fc_self.bias
self.fc_self.bias = None
self.bias = bias
def _lstm_reducer(self, nodes): def _lstm_reducer(self, nodes):
"""LSTM reducer """LSTM reducer
NOTE(zihao): lstm reducer with default schedule (degree bucketing) NOTE(zihao): lstm reducer with default schedule (degree bucketing)
...@@ -183,6 +200,7 @@ class SAGEConv(nn.Module): ...@@ -183,6 +200,7 @@ class SAGEConv(nn.Module):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature. is size of output feature.
""" """
self._compatibility_check()
with graph.local_scope(): with graph.local_scope():
if isinstance(feat, tuple): if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0]) feat_src = self.feat_drop(feat[0])
...@@ -191,11 +209,11 @@ class SAGEConv(nn.Module): ...@@ -191,11 +209,11 @@ class SAGEConv(nn.Module):
feat_src = feat_dst = self.feat_drop(feat) feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block: if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()] feat_dst = feat_src[:graph.number_of_dst_nodes()]
aggregate_fn = fn.copy_src('h', 'm') msg_fn = fn.copy_src('h', 'm')
if edge_weight is not None: if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges() assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm') msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')
h_self = feat_dst h_self = feat_dst
...@@ -204,34 +222,50 @@ class SAGEConv(nn.Module): ...@@ -204,34 +222,50 @@ class SAGEConv(nn.Module):
graph.dstdata['neigh'] = torch.zeros( graph.dstdata['neigh'] = torch.zeros(
feat_dst.shape[0], self._in_src_feats).to(feat_dst) feat_dst.shape[0], self._in_src_feats).to(feat_dst)
# Determine whether to apply linear transformation before message passing A(XW)
lin_before_mp = self._in_src_feats > self._out_feats
# Message Passing
if self._aggre_type == 'mean': if self._aggre_type == 'mean':
graph.srcdata['h'] = feat_src graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
graph.update_all(aggregate_fn, fn.mean('m', 'neigh')) graph.update_all(msg_fn, fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh'] h_neigh = graph.dstdata['neigh']
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)
elif self._aggre_type == 'gcn': elif self._aggre_type == 'gcn':
check_eq_shape(feat) check_eq_shape(feat)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous if isinstance(feat, tuple): # heterogeneous
graph.update_all(aggregate_fn, fn.sum('m', 'neigh')) graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
else:
graph.dstdata['h'] = graph.srcdata['h']
graph.update_all(msg_fn, fn.sum('m', 'neigh'))
# divide in_degrees # divide in_degrees
degs = graph.in_degrees().to(feat_dst) degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1) h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)
elif self._aggre_type == 'pool': elif self._aggre_type == 'pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src)) graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(aggregate_fn, fn.max('m', 'neigh')) graph.update_all(msg_fn, fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh'] h_neigh = self.fc_neigh(graph.dstdata['neigh'])
elif self._aggre_type == 'lstm': elif self._aggre_type == 'lstm':
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, self._lstm_reducer) graph.update_all(msg_fn, self._lstm_reducer)
h_neigh = graph.dstdata['neigh'] h_neigh = self.fc_neigh(graph.dstdata['neigh'])
else: else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type)) raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
# GraphSAGE GCN does not require fc_self. # GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn': if self._aggre_type == 'gcn':
rst = self.fc_neigh(h_neigh) rst = h_neigh
else: else:
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) rst = self.fc_self(h_self) + h_neigh
# bias term
if self.bias is not None:
rst = rst + self.bias
# activation # activation
if self.activation is not None: if self.activation is not None:
rst = self.activation(rst) rst = self.activation(rst)
......
...@@ -806,7 +806,7 @@ def test_dense_sage_conv(g, idtype, out_dim): ...@@ -806,7 +806,7 @@ def test_dense_sage_conv(g, idtype, out_dim):
sage = nn.SAGEConv(5, out_dim, 'gcn') sage = nn.SAGEConv(5, out_dim, 'gcn')
dense_sage = nn.DenseSAGEConv(5, out_dim) dense_sage = nn.DenseSAGEConv(5, out_dim)
dense_sage.fc.weight.data = sage.fc_neigh.weight.data dense_sage.fc.weight.data = sage.fc_neigh.weight.data
dense_sage.fc.bias.data = sage.fc_neigh.bias.data dense_sage.fc.bias.data = sage.bias.data
if len(g.ntypes) == 2: if len(g.ntypes) == 2:
feat = ( feat = (
F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_src_nodes(), 5)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment