"...python/git@developer.sourcefind.cn:change/sglang.git" did not exist on "84b006b27833d93045ae5552e2cebb13f5140ab5"
Unverified commit 7b4b8129, authored by Jinjing Zhou and committed by GitHub

Add edge_weight parameters to nn modules (#3455)

* add

* fix optional docstring

* fix

* lint

* add normalization by edge weight

* add test

* fix

* lint

* fix docs

* fix

* fix docs
parent 9e7fbf95
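
This change adds an optional `edge_weight` argument to APPNPConv, GCN2Conv, SGConv, and TAGConv. When it is given, each layer normalizes the weights with `EdgeWeightNorm('both')`, so message passing behaves like propagation over a symmetrically normalized weighted adjacency matrix; without it, the previous degree-based normalization path is kept. A minimal usage sketch (the toy graph, feature size, and weight values are illustrative choices, not taken from this PR; the module and argument names are DGL's public API):

import dgl
import torch as th
from dgl.nn.pytorch import APPNPConv

# Illustrative 4-node cycle, so every node has nonzero in/out degree
# (EdgeWeightNorm rejects graphs with 0-in-degree nodes).
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feat = th.randn(g.num_nodes(), 5)
eweight = th.rand(g.num_edges()) + 0.1  # one positive scalar weight per edge

conv = APPNPConv(k=3, alpha=0.5)
# Without edge_weight the layer normalizes by node degrees as before;
# with edge_weight it normalizes the weights via EdgeWeightNorm('both').
out = conv(g, feat, edge_weight=eweight)
print(out.shape)  # torch.Size([4, 5])
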
@@ -4,6 +4,8 @@ import torch as th
 from torch import nn
 
 from .... import function as fn
+from .graphconv import EdgeWeightNorm
+
 
 class APPNPConv(nn.Module):
     r"""
@@ -57,6 +59,7 @@ class APPNPConv(nn.Module):
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
             0.5000]])
    """
+
    def __init__(self,
                 k,
                 alpha,
@@ -66,7 +69,7 @@ class APPNPConv(nn.Module):
         self._alpha = alpha
         self.edge_drop = nn.Dropout(edge_drop)
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, edge_weight=None):
         r"""
 
         Description
@@ -80,6 +83,11 @@ class APPNPConv(nn.Module):
         feat : torch.Tensor
             The input feature of shape :math:`(N, *)`. :math:`N` is the
             number of nodes, and :math:`*` could be of any shape.
+        edge_weight : torch.Tensor, optional
+            Edge weights to use in the message passing process. This is
+            equivalent to using a weighted adjacency matrix in the equation
+            above, and :math:`\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}`
+            is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.
 
         Returns
         -------
@@ -88,23 +96,33 @@ class APPNPConv(nn.Module):
             should be the same as input shape.
         """
         with graph.local_scope():
-            src_norm = th.pow(graph.out_degrees().float().clamp(min=1), -0.5)
-            shp = src_norm.shape + (1,) * (feat.dim() - 1)
-            src_norm = th.reshape(src_norm, shp).to(feat.device)
-            dst_norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
-            shp = dst_norm.shape + (1,) * (feat.dim() - 1)
-            dst_norm = th.reshape(dst_norm, shp).to(feat.device)
+            if edge_weight is None:
+                src_norm = th.pow(
+                    graph.out_degrees().float().clamp(min=1), -0.5)
+                shp = src_norm.shape + (1,) * (feat.dim() - 1)
+                src_norm = th.reshape(src_norm, shp).to(feat.device)
+                dst_norm = th.pow(
+                    graph.in_degrees().float().clamp(min=1), -0.5)
+                shp = dst_norm.shape + (1,) * (feat.dim() - 1)
+                dst_norm = th.reshape(dst_norm, shp).to(feat.device)
+            else:
+                edge_weight = EdgeWeightNorm(
+                    'both')(graph, edge_weight)
             feat_0 = feat
             for _ in range(self._k):
                 # normalization by src node
-                feat = feat * src_norm
+                if edge_weight is None:
+                    feat = feat * src_norm
                 graph.ndata['h'] = feat
+                if edge_weight is None:
+                    edge_weight = th.ones(graph.number_of_edges(), 1)
                 graph.edata['w'] = self.edge_drop(
-                    th.ones(graph.number_of_edges(), 1).to(feat.device))
+                    edge_weight).to(feat.device)
                 graph.update_all(fn.u_mul_e('h', 'w', 'm'),
                                  fn.sum('m', 'h'))
                 feat = graph.ndata.pop('h')
                 # normalization by dst node
-                feat = feat * dst_norm
+                if edge_weight is None:
+                    feat = feat * dst_norm
                 feat = (1 - self._alpha) * feat + self._alpha * feat_0
             return feat
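
The `else` branch above hands the normalization to `EdgeWeightNorm('both')`. As I read that module, it rescales each weight by the square roots of the weighted out-degree of the source and the weighted in-degree of the destination; the manual computation below is a sketch of that reading, not part of the diff, and the graph and weights are illustrative:

import dgl
import torch as th
from dgl.nn.pytorch import EdgeWeightNorm

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
w = th.rand(g.num_edges()) + 0.1  # strictly positive, as EdgeWeightNorm requires

# Normalization used by the new edge_weight branches.
w_norm = EdgeWeightNorm('both')(g, w)

# Manual version of the same idea: w_uv / sqrt(out_w(u) * in_w(v)),
# where out_w / in_w are weighted out-/in-degrees built from the edge list.
u, v = g.edges()
out_w = th.zeros(g.num_nodes()).scatter_add_(0, u, w)
in_w = th.zeros(g.num_nodes()).scatter_add_(0, v, w)
w_manual = w / (out_w[u] * in_w[v]).sqrt()

print(th.allclose(w_norm, w_manual))  # expected True under this reading
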
@@ -8,6 +8,7 @@ from torch import nn
 from .... import function as fn
 from ....base import DGLError
+from .graphconv import EdgeWeightNorm
 
 
 class GCN2Conv(nn.Module):
@@ -34,7 +35,7 @@ class GCN2Conv(nn.Module):
     :math:`\alpha` is the fraction of initial node features, and
     :math:`\beta_l` is the hyperparameter to tune the strength of identity mapping.
     It is defined by :math:`\beta_l = \log(\frac{\lambda}{l}+1)\approx\frac{\lambda}{l}`,
-    where :math:`\lambda` is a hyperparameter. :math: `\beta` ensures that the decay of
+    where :math:`\lambda` is a hyperparameter. :math:`\beta` ensures that the decay of
     the weight matrix adaptively increases as we stack more layers.
 
     Parameters
...@@ -133,7 +134,8 @@ class GCN2Conv(nn.Module): ...@@ -133,7 +134,8 @@ class GCN2Conv(nn.Module):
if self._project_initial_features: if self._project_initial_features:
self.register_parameter("weight2", None) self.register_parameter("weight2", None)
else: else:
self.weight2 = nn.Parameter(th.Tensor(self._in_feats, self._in_feats)) self.weight2 = nn.Parameter(
th.Tensor(self._in_feats, self._in_feats))
if self._bias: if self._bias:
self.bias = nn.Parameter(th.Tensor(self._in_feats)) self.bias = nn.Parameter(th.Tensor(self._in_feats))
@@ -170,7 +172,7 @@ class GCN2Conv(nn.Module):
         """
         self._allow_zero_in_degree = set_value
 
-    def forward(self, graph, feat, feat_0):
+    def forward(self, graph, feat, feat_0, edge_weight=None):
         r"""
 
         Description
@@ -187,6 +189,12 @@ class GCN2Conv(nn.Module):
             where :math:`D_{in}` is the size of input feature and :math:`N` is the number of nodes.
         feat_0 : torch.Tensor
             The initial feature of shape :math:`(N, D_{in})`
+        edge_weight : torch.Tensor, optional
+            Edge weights to use in the message passing process. This is
+            equivalent to using a weighted adjacency matrix in the equation
+            above, and :math:`\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}`
+            is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.
+
         Returns
         -------
@@ -224,14 +232,23 @@ class GCN2Conv(nn.Module):
                 )
 
             # normalize to get smoothed representation
-            degs = graph.in_degrees().float().clamp(min=1)
-            norm = th.pow(degs, -0.5)
-            norm = norm.to(feat.device).unsqueeze(1)
+            if edge_weight is None:
+                degs = graph.in_degrees().float().clamp(min=1)
+                norm = th.pow(degs, -0.5)
+                norm = norm.to(feat.device).unsqueeze(1)
+            else:
+                edge_weight = EdgeWeightNorm('both')(graph, edge_weight)
 
-            feat = feat * norm
+            if edge_weight is None:
+                feat = feat * norm
             graph.ndata["h"] = feat
-            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
+            msg_func = fn.copy_u("h", "m")
+            if edge_weight is not None:
+                graph.edata["_edge_weight"] = edge_weight
+                msg_func = fn.u_mul_e("h", "_edge_weight", "m")
+            graph.update_all(msg_func, fn.sum("m", "h"))
             feat = graph.ndata.pop("h")
-            feat = feat * norm
+            if edge_weight is None:
+                feat = feat * norm
 
             # scale
             feat = feat * (1 - self.alpha)
...
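
GCN2Conv keeps its two feature arguments (current and initial representations) and only gains the keyword. A minimal sketch mirroring the constructor arguments used in the new test further down; the toy graph and feature size are illustrative:

import dgl
import torch as th
from dgl.nn.pytorch import GCN2Conv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feat = th.randn(g.num_nodes(), 5)
eweight = th.rand(g.num_edges()) + 0.1

conv = GCN2Conv(5, layer=2, alpha=0.5, project_initial_features=True)
# feat doubles as the initial representation here, as in the new test.
out = conv(g, feat, feat, edge_weight=eweight)
print(out.shape)  # torch.Size([4, 5])
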
@@ -5,6 +5,8 @@ from torch import nn
 from .... import function as fn
 from ....base import DGLError
+from .graphconv import EdgeWeightNorm
+
 
 class SGConv(nn.Module):
     r"""
@@ -83,6 +85,7 @@ class SGConv(nn.Module):
            [-1.9297, -0.9273],
            [-1.9441, -0.9343]], grad_fn=<AddmmBackward>)
    """
+
    def __init__(self,
                 in_feats,
                 out_feats,
@@ -130,7 +133,7 @@ class SGConv(nn.Module):
         """
         self._allow_zero_in_degree = set_value
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, edge_weight=None):
         r"""
 
         Description
@@ -144,6 +147,11 @@ class SGConv(nn.Module):
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
+        edge_weight : torch.Tensor, optional
+            Edge weights to use in the message passing process. This is
+            equivalent to using a weighted adjacency matrix in the equation
+            above, and :math:`\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}`
+            is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.
 
         Returns
         -------
@@ -176,20 +184,29 @@ class SGConv(nn.Module):
                                    'to be `True` when constructing this module will '
                                    'suppress the check and let the code run.')
 
+            msg_func = fn.copy_u("h", "m")
+            if edge_weight is not None:
+                graph.edata["_edge_weight"] = EdgeWeightNorm(
+                    'both')(graph, edge_weight)
+                msg_func = fn.u_mul_e("h", "_edge_weight", "m")
+
             if self._cached_h is not None:
                 feat = self._cached_h
             else:
-                # compute normalization
-                degs = graph.in_degrees().float().clamp(min=1)
-                norm = th.pow(degs, -0.5)
-                norm = norm.to(feat.device).unsqueeze(1)
+                if edge_weight is None:
+                    # compute normalization
+                    degs = graph.in_degrees().float().clamp(min=1)
+                    norm = th.pow(degs, -0.5)
+                    norm = norm.to(feat.device).unsqueeze(1)
                 # compute (D^-1 A^k D)^k X
                 for _ in range(self._k):
-                    feat = feat * norm
+                    if edge_weight is None:
+                        feat = feat * norm
                     graph.ndata['h'] = feat
-                    graph.update_all(fn.copy_u('h', 'm'),
+                    graph.update_all(msg_func,
                                      fn.sum('m', 'h'))
                     feat = graph.ndata.pop('h')
-                    feat = feat * norm
+                    if edge_weight is None:
+                        feat = feat * norm
 
                 if self.norm is not None:
...
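
SGConv builds `msg_func` once before the propagation loop and reuses it for all `k` hops. A minimal sketch on an illustrative graph (sizes and weights are illustrative, not from the PR):

import dgl
import torch as th
from dgl.nn.pytorch import SGConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feat = th.randn(g.num_nodes(), 5)
eweight = th.rand(g.num_edges()) + 0.1

conv = SGConv(in_feats=5, out_feats=5, k=3)
out = conv(g, feat, edge_weight=eweight)
print(out.shape)  # torch.Size([4, 5])
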
@@ -4,6 +4,7 @@ import torch as th
 from torch import nn
 
 from .... import function as fn
+from .graphconv import EdgeWeightNorm
 
 
 class TAGConv(nn.Module):
@@ -59,6 +60,7 @@ class TAGConv(nn.Module):
            [ 0.5215, -1.6044],
            [ 0.3304, -1.9927]], grad_fn=<AddmmBackward>)
    """
+
    def __init__(self,
                 in_feats,
                 out_feats,
@@ -89,7 +91,7 @@ class TAGConv(nn.Module):
         gain = nn.init.calculate_gain('relu')
         nn.init.xavier_normal_(self.lin.weight, gain=gain)
 
-    def forward(self, graph, feat):
+    def forward(self, graph, feat, edge_weight=None):
         r"""
 
         Description
@@ -103,6 +105,11 @@ class TAGConv(nn.Module):
         feat : torch.Tensor
             The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
             is size of input feature, :math:`N` is the number of nodes.
+        edge_weight : torch.Tensor, optional
+            Edge weights to use in the message passing process. This is
+            equivalent to using a weighted adjacency matrix in the equation
+            above, and :math:`\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}`
+            is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.
 
         Returns
         -------
@@ -112,21 +119,29 @@ class TAGConv(nn.Module):
        """
        with graph.local_scope():
            assert graph.is_homogeneous, 'Graph is not homogeneous'
-            norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
-            shp = norm.shape + (1,) * (feat.dim() - 1)
-            norm = th.reshape(norm, shp).to(feat.device)
+            if edge_weight is None:
+                norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
+                shp = norm.shape + (1,) * (feat.dim() - 1)
+                norm = th.reshape(norm, shp).to(feat.device)
 
-            #D-1/2 A D -1/2 X
+            msg_func = fn.copy_u("h", "m")
+            if edge_weight is not None:
+                graph.edata["_edge_weight"] = EdgeWeightNorm(
+                    'both')(graph, edge_weight)
+                msg_func = fn.u_mul_e("h", "_edge_weight", "m")
+
+            # D-1/2 A D -1/2 X
            fstack = [feat]
            for _ in range(self._k):
-                rst = fstack[-1] * norm
+                if edge_weight is None:
+                    rst = fstack[-1] * norm
+                else:
+                    rst = fstack[-1]
 
                graph.ndata['h'] = rst
-                graph.update_all(fn.copy_src(src='h', out='m'),
+                graph.update_all(msg_func,
                                 fn.sum(msg='m', out='h'))
                rst = graph.ndata['h']
-                rst = rst * norm
+                if edge_weight is None:
+                    rst = rst * norm
                fstack.append(rst)
...
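
TAGConv follows the same pattern: it skips the degree-based scaling when weights are given and multiplies by the normalized weights inside `update_all` instead. A minimal sketch on an illustrative graph:

import dgl
import torch as th
from dgl.nn.pytorch import TAGConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feat = th.randn(g.num_nodes(), 5)
eweight = th.rand(g.num_edges()) + 0.1

conv = TAGConv(in_feats=5, out_feats=5, k=2, bias=True)
out = conv(g, feat, edge_weight=eweight)
print(out.shape)  # torch.Size([4, 5])
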
@@ -715,6 +715,60 @@ def test_appnp_conv(g, idtype):
     h = appnp(g, feat)
     assert h.shape[-1] == 5
 
+
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
+def test_appnp_conv_e_weight(g, idtype):
+    ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
+    appnp = nn.APPNPConv(10, 0.1)
+    feat = F.randn((g.number_of_nodes(), 5))
+    eweight = F.ones((g.num_edges(), ))
+    appnp = appnp.to(ctx)
+    h = appnp(g, feat, edge_weight=eweight)
+    assert h.shape[-1] == 5
+
+
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
+def test_gcn2conv_e_weight(g, idtype):
+    ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
+    gcn2conv = nn.GCN2Conv(5, layer=2, alpha=0.5,
+                           project_initial_features=True)
+    feat = F.randn((g.number_of_nodes(), 5))
+    eweight = F.ones((g.num_edges(), ))
+    gcn2conv = gcn2conv.to(ctx)
+    res = feat
+    h = gcn2conv(g, res, feat, edge_weight=eweight)
+    assert h.shape[-1] == 5
+
+
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
+def test_sgconv_e_weight(g, idtype):
+    ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
+    sgconv = nn.SGConv(5, 5, 3)
+    feat = F.randn((g.number_of_nodes(), 5))
+    eweight = F.ones((g.num_edges(), ))
+    sgconv = sgconv.to(ctx)
+    h = sgconv(g, feat, edge_weight=eweight)
+    assert h.shape[-1] == 5
+
+
+@parametrize_dtype
+@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
+def test_tagconv_e_weight(g, idtype):
+    ctx = F.ctx()
+    g = g.astype(idtype).to(ctx)
+    conv = nn.TAGConv(5, 5, bias=True)
+    conv = conv.to(ctx)
+    feat = F.randn((g.number_of_nodes(), 5))
+    eweight = F.ones((g.num_edges(), ))
+    conv = conv.to(ctx)
+    h = conv(g, feat, edge_weight=eweight)
+    assert h.shape[-1] == 5
+
+
 @parametrize_dtype
 @pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
 @pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
...