Unverified commit 0855d255, authored by Tong He, committed by GitHub

[NN] Support scalar edge weight for GraphConv, SAGEConv and GINConv (#2557)

* add edge weight in forward

* fix lint

* fix

* fix

* address comments

* add utils

* add util to normalize in gcn way

* fix lint

* add unittest

* fix lint

* fix docstring

* fix docstring

* address comments

* improve notation consistency

* use preferred fn
parent 8900450d
@@ -17,6 +17,13 @@ GraphConv
    :members: weight, bias, forward, reset_parameters
    :show-inheritance:

EdgeWeightNorm
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.pytorch.conv.EdgeWeightNorm
    :members: forward
    :show-inheritance:

RelGraphConv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -235,15 +235,15 @@ def src_mul_edge(src, edge, out):
    ----------
    src : str
        The source feature field.
    edge : str
        The edge feature field.
    out : str
        The output message field.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.src_mul_edge('h', 'e', 'm')
    """
    return getattr(sys.modules[__name__], "u_mul_e")(src, edge, out)
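Since ``src_mul_edge`` simply forwards to ``u_mul_e`` (see the ``return`` above), the two builders below are interchangeable; the second spelling is the preferred one per this commit:

    >>> import dgl.function as fn
    >>> message_func = fn.src_mul_edge('h', 'e', 'm')  # legacy alias
    >>> message_func = fn.u_mul_e('h', 'e', 'm')       # preferred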
...
@@ -8,7 +8,7 @@ from .edgeconv import EdgeConv
from .gatconv import GATConv
from .ginconv import GINConv
from .gmmconv import GMMConv
from .graphconv import GraphConv, EdgeWeightNorm
from .nnconv import NNConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
@@ -22,7 +22,7 @@ from .atomicconv import AtomicConv
from .cfconv import CFConv
from .dotgatconv import DotGatConv

__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
           'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
           'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
           'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv']
@@ -20,6 +20,16 @@ class GINConv(nn.Module):
        \mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i)
        \right\}\right)\right)
    If a weight tensor on each edge is provided, the weighted graph convolution is defined as:

    .. math::
        h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
        \mathrm{aggregate}\left(\left\{e_{ji} h_j^{l}, j\in\mathcal{N}(i)
        \right\}\right)\right)

    where :math:`e_{ji}` is the weight on the edge from node :math:`j` to node :math:`i`.
    Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.
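As a quick illustration of the new argument, a minimal usage sketch (graph, sizes, and values invented for this example):

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import GINConv
    >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    >>> feat = th.ones(4, 5)
    >>> conv = GINConv(th.nn.Linear(5, 5), 'sum')
    >>> edge_weight = th.rand(g.number_of_edges())  # one scalar weight per edge
    >>> res = conv(g, feat, edge_weight=edge_weight)
    >>> res.shape
    torch.Size([4, 5])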
    Parameters
    ----------
    apply_func : callable activation function/layer or None
@@ -80,7 +90,7 @@ class GINConv(nn.Module):
        else:
            self.register_buffer('eps', th.FloatTensor([init_eps]))
    def forward(self, graph, feat, edge_weight=None):
        r"""
        Description
@@ -98,6 +108,9 @@ class GINConv(nn.Module):
            :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
            If ``apply_func`` is not None, :math:`D_{in}` should
            fit the input dimensionality requirement of ``apply_func``.
        edge_weight : torch.Tensor, optional
            Optional scalar weight on each edge. If given, the messages are
            multiplied by the edge weights during aggregation.

        Returns
        -------
@@ -108,9 +121,15 @@ class GINConv(nn.Module):
            as input dimensionality.
        """
        with graph.local_scope():
            aggregate_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
            feat_src, feat_dst = expand_as_pair(feat, graph)
            graph.srcdata['h'] = feat_src
            graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))
            rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
            if self.apply_func is not None:
                rst = self.apply_func(rst)
...
@@ -7,6 +7,135 @@ from torch.nn import init
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
from ....transform import reverse
from ....convert import block_to_graph
from ....heterograph import DGLBlock
class EdgeWeightNorm(nn.Module):
r"""
Description
-----------
This module normalizes positive scalar edge weights on a graph
following the form in `GCN <https://arxiv.org/abs/1609.02907>`__.
Mathematically, setting ``norm='both'`` yields the following normalization term:
.. math:
c_{ji} = (\sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}}\sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}})
And, setting ``norm='right'`` yields the following normalization term:
.. math:
c_{ji} = (\sum_{k\in\mathcal{N}(i)}}e_{ki})
where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
The module returns the normalized weight :math:`e_{ji} / c_{ji}`.
Parameters
----------
norm : str, optional
The normalizer as specified above. Default is `'both'`.
eps : float, optional
A small offset value in the denominator. Default is 0.
Examples
--------
>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import EdgeWeightNorm, GraphConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> edge_weight = th.tensor([0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1])
>>> norm = EdgeWeightNorm(norm='both')
>>> norm_edge_weight = norm(g, edge_weight)
>>> conv = GraphConv(10, 2, norm='none', weight=True, bias=True)
>>> res = conv(g, feat, edge_weight=norm_edge_weight)
>>> print(res)
tensor([[-1.1849, -0.7525],
[-1.3514, -0.8582],
[-1.2384, -0.7865],
[-1.9949, -1.2669],
[-1.3658, -0.8674],
[-0.8323, -0.5286]], grad_fn=<AddBackward0>)
"""
    def __init__(self, norm='both', eps=0.):
        super(EdgeWeightNorm, self).__init__()
        self._norm = norm
        self._eps = eps
    def forward(self, graph, edge_weight):
        r"""

        Description
        -----------
        Compute normalized edge weight for the GCN model.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        edge_weight : torch.Tensor
            Unnormalized scalar weights on the edges.
            The shape is expected to be :math:`(|E|)`.

        Returns
        -------
        torch.Tensor
            The normalized edge weight.

        Raises
        ------
        DGLError
            Case 1:
            The edge weight is multi-dimensional. Currently this module
            only supports a scalar weight on each edge.
            Case 2:
            The edge weight has non-positive values with ``norm='both'``.
            This will trigger square root and division by a non-positive number.
        """
        with graph.local_scope():
            if isinstance(graph, DGLBlock):
                graph = block_to_graph(graph)
            if len(edge_weight.shape) > 1:
                raise DGLError('Currently the normalization is only defined '
                               'on scalar edge weight. Please customize the '
                               'normalization for your high-dimensional weights.')
            if self._norm == 'both' and th.any(edge_weight <= 0).item():
                raise DGLError('Non-positive edge weight detected with `norm="both"`. '
                               'This leads to square root of zero or negative values.')

            dev = graph.device
            graph.srcdata['_src_out_w'] = th.ones((graph.number_of_src_nodes())).float().to(dev)
            graph.dstdata['_dst_in_w'] = th.ones((graph.number_of_dst_nodes())).float().to(dev)
            graph.edata['_edge_w'] = edge_weight

            if self._norm == 'both':
                reversed_g = reverse(graph)
                reversed_g.edata['_edge_w'] = edge_weight
                reversed_g.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'out_weight'))
                degs = reversed_g.dstdata['out_weight'] + self._eps
                norm = th.pow(degs, -0.5)
                graph.srcdata['_src_out_w'] = norm

            if self._norm != 'none':
                graph.update_all(fn.copy_edge('_edge_w', 'm'), fn.sum('m', 'in_weight'))
                degs = graph.dstdata['in_weight'] + self._eps
                if self._norm == 'both':
                    norm = th.pow(degs, -0.5)
                else:
                    norm = 1.0 / degs
                graph.dstdata['_dst_in_w'] = norm

            graph.apply_edges(lambda e: {'_norm_edge_weights': e.src['_src_out_w'] * \
                                                               e.dst['_dst_in_w'] * \
                                                               e.data['_edge_w']})
            return graph.edata['_norm_edge_weights']
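To make the normalization concrete, here is a small hand-checkable sketch (graph and weights invented for this example). With ``norm='right'``, each edge weight is divided by the total incoming weight of its destination node:

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import EdgeWeightNorm
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 1]))  # edges 0->1, 1->2, 2->1
    >>> w = th.tensor([1., 2., 3.])
    >>> EdgeWeightNorm(norm='right')(g, w)  # node 1 receives 1 + 3, node 2 receives 2
    tensor([0.2500, 1.0000, 0.7500])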
# pylint: disable=W0235
class GraphConv(nn.Module):
@@ -18,13 +147,25 @@ class GraphConv(nn.Module):
    and mathematically is defined as follows:

    .. math::
        h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ji}}h_j^{(l)}W^{(l)})

    where :math:`\mathcal{N}(i)` is the set of neighbors of node :math:`i`,
    :math:`c_{ji}` is the product of the square root of node degrees
    (i.e., :math:`c_{ji} = \sqrt{|\mathcal{N}(j)|}\sqrt{|\mathcal{N}(i)|}`),
    and :math:`\sigma` is an activation function.
    If a weight tensor on each edge is provided, the weighted graph convolution is defined as:

    .. math::
        h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{e_{ji}}{c_{ji}}h_j^{(l)}W^{(l)})

    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
    This is NOT equivalent to the weighted graph convolutional network formulation in the paper.

    To customize the normalization term :math:`c_{ji}`, one can first set ``norm='none'`` for
    the model, and send the pre-normalized :math:`e_{ji}` to the forward computation. We provide
    :class:`~dgl.nn.pytorch.EdgeWeightNorm` to normalize scalar edge weight following the GCN paper.
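For instance, one hypothetical custom scheme (invented for this sketch) divides each raw weight by the in-degree of the edge's destination before calling the layer with ``norm='none'``:

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import GraphConv
    >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    >>> feat = th.ones(4, 10)
    >>> e_w = th.rand(g.number_of_edges())
    >>> _, dst = g.edges()
    >>> e_w = e_w / g.in_degrees()[dst].float()  # custom c_ji = in-degree of i
    >>> conv = GraphConv(10, 2, norm='none')
    >>> res = conv(g, feat, edge_weight=e_w)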
    Parameters
    ----------
    in_feats : int
@@ -35,7 +176,7 @@ class GraphConv(nn.Module):
        How to apply the normalizer. If `'right'`, divide the aggregated messages
        by each node's in-degrees, which is equivalent to averaging the received messages.
        If `'none'`, no normalization is applied. Default is `'both'`,
        where the :math:`c_{ji}` in the paper is applied.
    weight : bool, optional
        If True, apply a linear layer. Otherwise, aggregate the messages
        without a weight matrix.
@@ -185,7 +326,7 @@ class GraphConv(nn.Module):
        """
        self._allow_zero_in_degree = set_value
    def forward(self, graph, feat, weight=None, edge_weight=None):
        r"""
        Description
@@ -205,6 +346,9 @@ class GraphConv(nn.Module):
            :math:`(N_{out}, D_{in_{dst}})`.
        weight : torch.Tensor, optional
            Optional external weight tensor.
        edge_weight : torch.Tensor, optional
            Optional scalar weight on each edge. If given, the messages are
            multiplied by the edge weights during aggregation.
        Returns
        -------
@@ -243,6 +387,11 @@ class GraphConv(nn.Module):
                               'the issue. Setting ``allow_zero_in_degree`` '
                               'to be `True` when constructing this module will '
                               'suppress the check and let the code run.')
            aggregate_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
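The swap from ``copy_src`` to ``u_mul_e`` is the whole mechanism: each source feature is multiplied by its edge's scalar weight before being summed at the destination. A tiny standalone sketch of the same pattern (graph and values invented):

    >>> import dgl
    >>> import dgl.function as fn
    >>> import torch as th
    >>> g = dgl.graph(([0, 1], [2, 2]))
    >>> g.ndata['h'] = th.tensor([[1.], [2.], [0.]])
    >>> g.edata['w'] = th.tensor([10., 100.])
    >>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_sum'))
    >>> g.ndata['h_sum']  # node 2 receives 1*10 + 2*100
    tensor([[  0.],
            [  0.],
            [210.]])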
            # (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
            feat_src, feat_dst = expand_as_pair(feat, graph)
@@ -266,14 +415,12 @@ class GraphConv(nn.Module):
                if weight is not None:
                    feat_src = th.matmul(feat_src, weight)
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
                rst = graph.dstdata['h']
            else:
                # aggregate first then mult W
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
                rst = graph.dstdata['h']
                if weight is not None:
                    rst = th.matmul(rst, weight)
...
@@ -25,6 +25,15 @@ class SAGEConv(nn.Module):
        h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{l})
    If a weight tensor on each edge is provided, the aggregation becomes:

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
        \left(\{e_{ji} h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
    Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.
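A minimal sketch of the new argument (sizes and values invented for this example):

    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import SAGEConv
    >>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    >>> feat = th.ones(4, 5)
    >>> conv = SAGEConv(5, 2, 'mean')
    >>> res = conv(g, feat, edge_weight=th.rand(g.number_of_edges()))
    >>> res.shape
    torch.Size([4, 2])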
    Parameters
    ----------
    in_feats : int, or pair of ints
@@ -147,7 +156,7 @@ class SAGEConv(nn.Module):
            _, (rst, _) = self.lstm(m, h)
            return {'neigh': rst.squeeze(0)}
    def forward(self, graph, feat, edge_weight=None):
        r"""
        Description
@@ -164,6 +173,9 @@ class SAGEConv(nn.Module):
            where :math:`D_{in}` is the size of the input feature and :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional scalar weight on each edge. If given, the messages are
            multiplied by the edge weights during aggregation.
        Returns
        -------
@@ -179,6 +191,11 @@ class SAGEConv(nn.Module):
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            aggregate_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
            h_self = feat_dst
@@ -189,23 +206,23 @@ class SAGEConv(nn.Module):
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst  # same as above if homogeneous
                graph.update_all(aggregate_fn, fn.sum('m', 'neigh'))
                # divide in_degrees
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(aggregate_fn, fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(aggregate_fn, self._lstm_reducer)
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
...
@@ -88,6 +88,45 @@ def test_graph_conv(idtype, g, norm, weight, bias):
        h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_e_weight(idtype, g, norm, weight, bias):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, 2)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    e_w = g.edata['scalar_w']
    if weight:
        h_out = conv(g, h, edge_weight=e_w)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=e_w)
    assert h_out.shape == (ndst, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@pytest.mark.parametrize('weight', [True, False])
@pytest.mark.parametrize('bias', [True, False])
def test_graph_conv_e_weight_norm(idtype, g, norm, weight, bias):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, 2, norm=norm, weight=weight, bias=bias).to(F.ctx())
    ext_w = F.randn((5, 2)).to(F.ctx())
    nsrc = g.number_of_src_nodes()
    ndst = g.number_of_dst_nodes()
    h = F.randn((nsrc, 5)).to(F.ctx())
    edgenorm = nn.EdgeWeightNorm(norm=norm)
    norm_weight = edgenorm(g, g.edata['scalar_w'])
    if weight:
        h_out = conv(g, h, edge_weight=norm_weight)
    else:
        h_out = conv(g, h, weight=ext_w, edge_weight=norm_weight)
    assert h_out.shape == (ndst, 2)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree', 'dglgraph']))
@pytest.mark.parametrize('norm', ['none', 'both', 'right'])
@@ -959,6 +998,8 @@ def test_hetero_conv(agg, idtype):
if __name__ == '__main__':
    test_graph_conv()
    test_graph_conv_e_weight()
    test_graph_conv_e_weight_norm()
    test_set2set()
    test_glob_att_pool()
    test_simple_pool()
...
@@ -56,6 +56,14 @@ def graph1():
    g.edata['w'] = F.copy_to(F.randn((g.number_of_edges(), 3)), F.cpu())
    return g
@register_case(['homo', 'has_scalar_e_feature'])
def graph1():
    g = dgl.graph(([0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 6, 6, 7, 8, 9],
                   [4, 5, 1, 2, 4, 7, 9, 8, 6, 4, 1, 0, 1, 0, 2, 3, 5]), device=F.cpu())
    g.ndata['h'] = F.copy_to(F.randn((g.number_of_nodes(), 2)), F.cpu())
    g.edata['scalar_w'] = F.copy_to(F.abs(F.randn((g.number_of_edges(),))), F.cpu())
    return g
@register_case(['hetero', 'has_feature'])
def heterograph0():
    g = dgl.heterograph({
...