Unverified commit e296c468, authored by bgawrych, committed by GitHub

[Optimization] Optimize bias term in SageConv layer (#4747)



* Optimize bias term in sageconv

* fix lint

* Remove bias sharing

* Update sageconv.py

Co-authored-by: Bartlomiej Gawrych <barlomiej.gawrych@intel.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 56ffb650
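
The change rests on a simple identity: projecting with two bias-free linear layers and then adding a shared bias vector is the same as folding that bias into one of the projections, where it is applied inside the fc_self matmul (a fused addmm) rather than as a separate elementwise kernel. A minimal PyTorch sketch of the equivalence (an illustration only, not code from this commit):

import torch
import torch.nn as nn

torch.manual_seed(0)
in_feats, out_feats = 16, 8
h_self = torch.randn(4, in_feats)
h_neigh = torch.randn(4, in_feats)

# Old layout: two bias-free projections plus a separately stored shared bias.
fc_self_old = nn.Linear(in_feats, out_feats, bias=False)
fc_neigh = nn.Linear(in_feats, out_feats, bias=False)
bias = torch.randn(out_feats)
rst_old = fc_self_old(h_self) + fc_neigh(h_neigh) + bias

# New layout: the same bias folded into fc_self, so it is applied inside the
# fc_self matmul instead of by an extra elementwise add.
fc_self_new = nn.Linear(in_feats, out_feats, bias=True)
with torch.no_grad():
    fc_self_new.weight.copy_(fc_self_old.weight)
    fc_self_new.bias.copy_(bias)
rst_new = fc_self_new(h_self) + fc_neigh(h_neigh)

assert torch.allclose(rst_old, rst_new, atol=1e-6)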
@@ -6,7 +6,7 @@ from torch.nn import functional as F
 from .... import function as fn
 from ....base import DGLError
-from ....utils import expand_as_pair, check_eq_shape, dgl_warning
+from ....utils import expand_as_pair, check_eq_shape
 
 
 class SAGEConv(nn.Module):
@@ -116,18 +116,20 @@ class SAGEConv(nn.Module):
         self.norm = norm
         self.feat_drop = nn.Dropout(feat_drop)
         self.activation = activation
         # aggregator type: mean/pool/lstm/gcn
         if aggregator_type == 'pool':
             self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
         if aggregator_type == 'lstm':
             self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
-        if aggregator_type != 'gcn':
-            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=False)
+
         self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
-        if bias:
+
+        if aggregator_type != 'gcn':
+            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
+        elif bias:
             self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
         else:
             self.register_buffer('bias', None)
+
         self.reset_parameters()
 
     def reset_parameters(self):
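
After this hunk, the mean/pool/lstm aggregators carry the bias inside fc_self, and only gcn, which builds no fc_self, keeps a standalone bias parameter. A quick look at the resulting parameter layout, assuming a DGL build that includes this patch:

from dgl.nn import SAGEConv

conv = SAGEConv(16, 8, 'mean', bias=True)
print(sorted(n for n, _ in conv.named_parameters()))
# ['fc_neigh.weight', 'fc_self.bias', 'fc_self.weight'] -- no standalone 'bias'

conv_gcn = SAGEConv(16, 8, 'gcn', bias=True)
print(sorted(n for n, _ in conv_gcn.named_parameters()))
# ['bias', 'fc_neigh.weight'] -- gcn builds no fc_self, so 'bias' stays standalone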
@@ -151,19 +153,6 @@ class SAGEConv(nn.Module):
         nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
         nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
 
-    def _compatibility_check(self):
-        """Address the backward compatibility issue brought by #2747"""
-        if not hasattr(self, 'bias'):
-            dgl_warning("You are loading a GraphSAGE model trained from a old version of DGL, "
-                        "DGL automatically convert it to be compatible with latest version.")
-            bias = self.fc_neigh.bias
-            self.fc_neigh.bias = None
-            if hasattr(self, 'fc_self'):
-                if bias is not None:
-                    bias = bias + self.fc_self.bias
-                self.fc_self.bias = None
-            self.bias = bias
-
     def _lstm_reducer(self, nodes):
         """LSTM reducer
         NOTE(zihao): lstm reducer with default schedule (degree bucketing)
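
Dropping _compatibility_check means checkpoints are no longer converted on the fly. A state dict saved with the shared-bias layout from #2747 could be translated by hand along these lines (a hypothetical helper with assumed key names; DGL ships no such function):

def migrate_sageconv_state_dict(state):
    """Fold the shared 'bias' entry of a pre-#4747 SAGEConv state dict
    into 'fc_self.bias' (hypothetical sketch, not part of this commit)."""
    state = dict(state)
    bias = state.pop('bias', None)
    if bias is not None and 'fc_self.weight' in state:
        # non-gcn aggregators: the bias now lives inside fc_self
        state['fc_self.bias'] = bias
    elif bias is not None:
        # gcn aggregator: the standalone bias parameter is kept
        state['bias'] = bias
    return state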
@@ -204,7 +193,6 @@ class SAGEConv(nn.Module):
         where :math:`N_{dst}` is the number of destination nodes in the input graph,
         :math:`D_{out}` is the size of the output feature.
         """
-        self._compatibility_check()
         with graph.local_scope():
             if isinstance(feat, tuple):
                 feat_src = self.feat_drop(feat[0])
@@ -266,12 +254,12 @@ class SAGEConv(nn.Module):
 
             # GraphSAGE GCN does not require fc_self.
             if self._aggre_type == 'gcn':
                 rst = h_neigh
+                # add bias manually for GCN
+                if self.bias is not None:
+                    rst = rst + self.bias
             else:
                 rst = self.fc_self(h_self) + h_neigh
-            # bias term
-            if self.bias is not None:
-                rst = rst + self.bias
 
             # activation
             if self.activation is not None:
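
On the forward path, the non-gcn branch now performs the bias addition as part of fc_self, while the gcn branch still adds it manually after aggregation. A short smoke test, again assuming a build with this patch:

import dgl
import torch
from dgl.nn import SAGEConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3-node cycle, every node has a neighbor
feat = torch.randn(3, 16)

out_mean = SAGEConv(16, 8, 'mean', bias=True)(g, feat)  # bias applied inside fc_self
out_gcn = SAGEConv(16, 8, 'gcn', bias=True)(g, feat)    # bias added after aggregation
print(out_mean.shape, out_gcn.shape)  # torch.Size([3, 8]) torch.Size([3, 8])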