"tests/python/common/sampling/test_sampling.py" did not exist on "86c81b4e927d94ed2dba76fc04e2088c6931e6b5"
Unverified commit 7b4b8129 authored by Jinjing Zhou, committed by GitHub

Add edge_weight parameters to nn modules (#3455)

* add

* fix optional docstring

* fix

* lint

* add normalized by edge weight

* add test

* fix

* lint

* fix docs

* fix

* fix docs
parent 9e7fbf95
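
For readers skimming the diff, here is a minimal usage sketch of the new parameter. This is assumed usage that mirrors the tests added at the bottom of this change; the toy graph, feature size, and weight values are illustrative, not part of the patch.

import dgl
import torch as th
from dgl.nn import APPNPConv

# Small toy graph with no zero-degree nodes (a 4-node cycle).
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
feat = th.randn(g.num_nodes(), 5)                 # node features
eweight = th.rand(g.num_edges()) + 0.1            # positive per-edge weights

conv = APPNPConv(k=10, alpha=0.1)
h_plain = conv(g, feat)                           # existing behaviour
h_weighted = conv(g, feat, edge_weight=eweight)   # new: weighted adjacency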
@@ -4,6 +4,8 @@ import torch as th
from torch import nn
from .... import function as fn
from .graphconv import EdgeWeightNorm
class APPNPConv(nn.Module):
r"""
@@ -57,6 +59,7 @@ class APPNPConv(nn.Module):
[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
0.5000]])
"""
def __init__(self,
k,
alpha,
@@ -66,7 +69,7 @@ class APPNPConv(nn.Module):
self._alpha = alpha
self.edge_drop = nn.Dropout(edge_drop)
def forward(self, graph, feat):
def forward(self, graph, feat, edge_weight=None):
r"""
Description
@@ -80,6 +83,11 @@ class APPNPConv(nn.Module):
feat : torch.Tensor
The input feature of shape :math:`(N, *)`. :math:`N` is the
number of nodes, and :math:`*` could be of any shape.
edge_weight : torch.Tensor, optional
Edge weights to use in the message passing process. This is equivalent to
using a weighted adjacency matrix in the equation above, and the normalization
:math:`\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}`
is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.
Returns
-------
@@ -88,23 +96,33 @@ class APPNPConv(nn.Module):
should be the same as input shape.
"""
with graph.local_scope():
src_norm = th.pow(graph.out_degrees().float().clamp(min=1), -0.5)
shp = src_norm.shape + (1,) * (feat.dim() - 1)
src_norm = th.reshape(src_norm, shp).to(feat.device)
dst_norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
shp = dst_norm.shape + (1,) * (feat.dim() - 1)
dst_norm = th.reshape(dst_norm, shp).to(feat.device)
if edge_weight is None:
src_norm = th.pow(
graph.out_degrees().float().clamp(min=1), -0.5)
shp = src_norm.shape + (1,) * (feat.dim() - 1)
src_norm = th.reshape(src_norm, shp).to(feat.device)
dst_norm = th.pow(
graph.in_degrees().float().clamp(min=1), -0.5)
shp = dst_norm.shape + (1,) * (feat.dim() - 1)
dst_norm = th.reshape(dst_norm, shp).to(feat.device)
else:
edge_weight = EdgeWeightNorm(
'both')(graph, edge_weight)
feat_0 = feat
for _ in range(self._k):
# normalization by src node
feat = feat * src_norm
if edge_weight is None:
feat = feat * src_norm
graph.ndata['h'] = feat
if edge_weight is None:
edge_weight = th.ones(graph.number_of_edges(), 1)
graph.edata['w'] = self.edge_drop(
th.ones(graph.number_of_edges(), 1).to(feat.device))
edge_weight).to(feat.device)
graph.update_all(fn.u_mul_e('h', 'w', 'm'),
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
# normalization by dst node
feat = feat * dst_norm
if edge_weight is None:
feat = feat * dst_norm
feat = (1 - self._alpha) * feat + self._alpha * feat_0
return feat
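
All four updated docstrings delegate the weighted normalization to EdgeWeightNorm('both'). As a rough, illustrative reading of what that normalization amounts to on a graph whose edges are stored in both directions (this sketch is the editor's reading, not DGL's exact implementation): each edge weight is divided by the square roots of the weighted degrees of its two endpoints, which collapses to the familiar D^-1/2 A D^-1/2 factor when all weights are 1.

import torch as th

def symmetric_edge_norm(src, dst, weight, num_nodes):
    # Illustrative only: weighted degree of each node from its incoming
    # edges (for a graph stored in both directions this also equals the
    # weighted out-degree).
    deg = th.zeros(num_nodes, dtype=weight.dtype).scatter_add_(0, dst, weight)
    deg = deg.clamp(min=1e-12)
    # w_uv / sqrt(d_u * d_v); with all-ones weights this matches the
    # src_norm / dst_norm factors computed in the unweighted branch above.
    return weight / (deg[src].sqrt() * deg[dst].sqrt())

# e.g. with the toy graph from the first sketch:
# norm_w = symmetric_edge_norm(*g.edges(), eweight, g.num_nodes())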
@@ -8,6 +8,7 @@ from torch import nn
from .... import function as fn
from ....base import DGLError
from .graphconv import EdgeWeightNorm
class GCN2Conv(nn.Module):
@@ -34,7 +35,7 @@ class GCN2Conv(nn.Module):
:math:`\alpha` is the fraction of initial node features, and
:math:`\beta_l` is the hyperparameter to tune the strength of identity mapping.
It is defined by :math:`\beta_l = \log(\frac{\lambda}{l}+1)\approx\frac{\lambda}{l}`,
where :math:`\lambda` is a hyperparameter. :math: `\beta` ensures that the decay of
where :math:`\lambda` is a hyperparameter. :math:`\beta` ensures that the decay of
the weight matrix adaptively increases as we stack more layers.
Parameters
@@ -133,7 +134,8 @@ class GCN2Conv(nn.Module):
if self._project_initial_features:
self.register_parameter("weight2", None)
else:
self.weight2 = nn.Parameter(th.Tensor(self._in_feats, self._in_feats))
self.weight2 = nn.Parameter(
th.Tensor(self._in_feats, self._in_feats))
if self._bias:
self.bias = nn.Parameter(th.Tensor(self._in_feats))
@@ -170,7 +172,7 @@ class GCN2Conv(nn.Module):
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat, feat_0):
def forward(self, graph, feat, feat_0, edge_weight=None):
r"""
Description
@@ -182,11 +184,17 @@ class GCN2Conv(nn.Module):
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape
The input feature of shape
:math:`(N, D_{in})`
where :math:`D_{in}` is the size of input feature and :math:`N` is the number of nodes.
feat_0 : torch.Tensor
The initial feature of shape :math:`(N, D_{in})`
The initial feature of shape :math:`(N, D_{in})`
edge_weight : torch.Tensor, optional
Edge weights to use in the message passing process. This is equivalent to
using a weighted adjacency matrix in the equation above, and the normalization
:math:`\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}`
is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.
Returns
-------
@@ -224,15 +232,24 @@ class GCN2Conv(nn.Module):
)
# normalize to get smoothed representation
degs = graph.in_degrees().float().clamp(min=1)
norm = th.pow(degs, -0.5)
norm = norm.to(feat.device).unsqueeze(1)
if edge_weight is None:
degs = graph.in_degrees().float().clamp(min=1)
norm = th.pow(degs, -0.5)
norm = norm.to(feat.device).unsqueeze(1)
else:
edge_weight = EdgeWeightNorm('both')(graph, edge_weight)
feat = feat * norm
if edge_weight is None:
feat = feat * norm
graph.ndata["h"] = feat
graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
msg_func = fn.copy_u("h", "m")
if edge_weight is not None:
graph.edata["_edge_weight"] = edge_weight
msg_func = fn.u_mul_e("h", "_edge_weight", "m")
graph.update_all(msg_func, fn.sum("m", "h"))
feat = graph.ndata.pop("h")
feat = feat * norm
if edge_weight is None:
feat = feat * norm
# scale
feat = feat * (1 - self.alpha)
......
@@ -5,6 +5,8 @@ from torch import nn
from .... import function as fn
from ....base import DGLError
from .graphconv import EdgeWeightNorm
class SGConv(nn.Module):
r"""
@@ -83,6 +85,7 @@ class SGConv(nn.Module):
[-1.9297, -0.9273],
[-1.9441, -0.9343]], grad_fn=<AddmmBackward>)
"""
def __init__(self,
in_feats,
out_feats,
@@ -130,7 +133,7 @@ class SGConv(nn.Module):
"""
self._allow_zero_in_degree = set_value
def forward(self, graph, feat):
def forward(self, graph, feat, edge_weight=None):
r"""
Description
@@ -144,6 +147,11 @@ class SGConv(nn.Module):
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
edge_weight : torch.Tensor, optional
Edge weights to use in the message passing process. This is equivalent to
using a weighted adjacency matrix in the equation above, and the normalization
:math:`\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}`
is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.
Returns
-------
@@ -176,21 +184,30 @@ class SGConv(nn.Module):
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
msg_func = fn.copy_u("h", "m")
if edge_weight is not None:
graph.edata["_edge_weight"] = EdgeWeightNorm(
'both')(graph, edge_weight)
msg_func = fn.u_mul_e("h", "_edge_weight", "m")
if self._cached_h is not None:
feat = self._cached_h
else:
# compute normalization
degs = graph.in_degrees().float().clamp(min=1)
norm = th.pow(degs, -0.5)
norm = norm.to(feat.device).unsqueeze(1)
if edge_weight is None:
# compute normalization
degs = graph.in_degrees().float().clamp(min=1)
norm = th.pow(degs, -0.5)
norm = norm.to(feat.device).unsqueeze(1)
# compute (D^-1 A^k D)^k X
for _ in range(self._k):
feat = feat * norm
if edge_weight is None:
feat = feat * norm
graph.ndata['h'] = feat
graph.update_all(fn.copy_u('h', 'm'),
graph.update_all(msg_func,
fn.sum('m', 'h'))
feat = graph.ndata.pop('h')
feat = feat * norm
if edge_weight is None:
feat = feat * norm
if self.norm is not None:
feat = self.norm(feat)
......
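
The SGConv change above (and the TAGConv change below) follow the same pattern: keep fn.copy_u as the default message function and switch to fn.u_mul_e when edge weights are supplied, so every message becomes the source feature scaled by its normalized edge weight. A small self-contained sketch of that pattern; the toy graph, feature size, and weight values are illustrative.

import dgl
import dgl.function as fn
import torch as th

g = dgl.graph(([0, 1, 2], [1, 2, 0]))            # 3-node cycle
g.ndata['h'] = th.randn(g.num_nodes(), 4)

# Unweighted aggregation: each message is just the source feature.
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_plain'))

# Weighted aggregation: multiply each message by its edge weight,
# i.e. h_v = sum over edges (u -> v) of w_uv * h_u.
g.edata['w'] = (th.rand(g.num_edges()) + 0.1).unsqueeze(-1)  # (E, 1) broadcasts over features
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_weighted'))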
@@ -4,6 +4,7 @@ import torch as th
from torch import nn
from .... import function as fn
from .graphconv import EdgeWeightNorm
class TAGConv(nn.Module):
@@ -59,6 +60,7 @@ class TAGConv(nn.Module):
[ 0.5215, -1.6044],
[ 0.3304, -1.9927]], grad_fn=<AddmmBackward>)
"""
def __init__(self,
in_feats,
out_feats,
@@ -89,7 +91,7 @@ class TAGConv(nn.Module):
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.lin.weight, gain=gain)
def forward(self, graph, feat):
def forward(self, graph, feat, edge_weight=None):
r"""
Description
@@ -103,6 +105,11 @@ class TAGConv(nn.Module):
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
edge_weight : torch.Tensor, optional
Edge weights to use in the message passing process. This is equivalent to
using a weighted adjacency matrix in the equation above, and the normalization
:math:`\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}`
is based on :class:`dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm`.
Returns
-------
@@ -112,22 +119,30 @@ class TAGConv(nn.Module):
"""
with graph.local_scope():
assert graph.is_homogeneous, 'Graph is not homogeneous'
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device)
#D-1/2 A D -1/2 X
if edge_weight is None:
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device)
msg_func = fn.copy_u("h", "m")
if edge_weight is not None:
graph.edata["_edge_weight"] = EdgeWeightNorm(
'both')(graph, edge_weight)
msg_func = fn.u_mul_e("h", "_edge_weight", "m")
# D-1/2 A D -1/2 X
fstack = [feat]
for _ in range(self._k):
rst = fstack[-1] * norm
if edge_weight is None:
rst = fstack[-1] * norm
else:
rst = fstack[-1]
graph.ndata['h'] = rst
graph.update_all(fn.copy_src(src='h', out='m'),
graph.update_all(msg_func,
fn.sum(msg='m', out='h'))
rst = graph.ndata['h']
rst = rst * norm
if edge_weight is None:
rst = rst * norm
fstack.append(rst)
rst = self.lin(th.cat(fstack, dim=-1))
......
@@ -715,6 +715,60 @@ def test_appnp_conv(g, idtype):
h = appnp(g, feat)
assert h.shape[-1] == 5
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_appnp_conv_e_weight(g, idtype):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
appnp = nn.APPNPConv(10, 0.1)
feat = F.randn((g.number_of_nodes(), 5))
eweight = F.ones((g.num_edges(), ))
appnp = appnp.to(ctx)
h = appnp(g, feat, edge_weight=eweight)
assert h.shape[-1] == 5
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_gcn2conv_e_weight(g, idtype):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
gcn2conv = nn.GCN2Conv(5, layer=2, alpha=0.5,
project_initial_features=True)
feat = F.randn((g.number_of_nodes(), 5))
eweight = F.ones((g.num_edges(), ))
gcn2conv = gcn2conv.to(ctx)
res = feat
h = gcn2conv(g, res, feat, edge_weight=eweight)
assert h.shape[-1] == 5
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_sgconv_e_weight(g, idtype):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
sgconv = nn.SGConv(5, 5, 3)
feat = F.randn((g.number_of_nodes(), 5))
eweight = F.ones((g.num_edges(), ))
sgconv = sgconv.to(ctx)
h = sgconv(g, feat, edge_weight=eweight)
assert h.shape[-1] == 5
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
def test_tagconv_e_weight(g, idtype):
ctx = F.ctx()
g = g.astype(idtype).to(ctx)
conv = nn.TAGConv(5, 5, bias=True)
conv = conv.to(ctx)
feat = F.randn((g.number_of_nodes(), 5))
eweight = F.ones((g.num_edges(), ))
conv = conv.to(ctx)
h = conv(g, feat, edge_weight=eweight)
assert h.shape[-1] == 5
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('aggregator_type', ['mean', 'max', 'sum'])
......