Unverified Commit 3aef4677 authored by Mufei Li, committed by GitHub

[Doc] User Guide for Using Edge Weights in Training (#3509)

* Update

* Update

* CI
parent 81915f55
.. _guide-message-passing-edge:
2.4 Apply Edge Weight In Message Passing
----------------------------------------
:ref:`(Chinese Version) <guide_cn-message-passing-edge>`
A common practice in GNN modeling is to apply edge weights to the
messages before message aggregation, for example, in
`GAT <https://arxiv.org/pdf/1710.10903.pdf>`__ and some `GCN
variants <https://arxiv.org/abs/2004.00445>`__. In DGL, the way to
handle this is:

- Save the weights as an edge feature.
- Multiply the source node feature by the edge feature in the message function.

For example:
.. code::

    import dgl.function as fn

    # Suppose eweight is a tensor of shape (E, *), where E is the number of edges.
    graph.edata['a'] = eweight
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
The example above uses ``eweight`` as the edge weight, which should
usually be a scalar for each edge.
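As a concrete end-to-end illustration, the snippet below runs the pattern above on a small synthetic graph (the graph, features, and weights are made up for this sketch):

.. code::

    import dgl
    import dgl.function as fn
    import torch

    # A toy graph with 3 nodes and 4 directed edges
    graph = dgl.graph(([0, 0, 1, 2], [1, 2, 2, 0]))
    graph.ndata['ft'] = torch.randn(3, 5)        # node features
    eweight = torch.rand(graph.num_edges(), 1)   # one scalar weight per edge

    graph.edata['a'] = eweight
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    print(graph.ndata['ft'].shape)               # torch.Size([3, 5])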
@@ -33,7 +33,6 @@ The last section of it explains how to implement message passing on heterogeneou
* :ref:`guide-message-passing-api`
* :ref:`guide-message-passing-efficient`
* :ref:`guide-message-passing-part`
* :ref:`guide-message-passing-edge`
* :ref:`guide-message-passing-heterograph`
.. toctree::
@@ -44,5 +43,4 @@ The last section of it explains how to implement message passing on heterogeneou
message-api
message-efficient
message-part
message-edge
message-heterograph
.. _guide-training-eweight:
5.5 Use of Edge Weights
----------------------------------
:ref:`(Chinese Version) <guide_cn-training-eweight>`
In a weighted graph, each edge is associated with a semantically meaningful scalar weight. For
example, the edge weights can be connectivity strengths or confidence scores. Naturally, one
may want to utilize edge weights in model development.
Message Passing with Edge Weights
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Most graph neural networks (GNNs) integrate graph topology information in forward computation
solely through the message passing mechanism. A message passing operation can be viewed as
a function that takes an adjacency matrix and additional input features as input arguments. For an
unweighted graph, the entries of the adjacency matrix are zero or one, where a one-valued entry
indicates an edge. If the graph is weighted, the non-zero entries can take arbitrary scalar
values. This is equivalent to multiplying each message by its corresponding edge weight, as in
`GAT <https://arxiv.org/pdf/1710.10903.pdf>`__.
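Concretely, with sum aggregation, weighted message passing computes the following (notation introduced here for illustration: :math:`w_{u,v}` denotes the weight of edge :math:`u \to v`, :math:`\mathcal{N}(v)` the in-neighbors of :math:`v`, and :math:`A_w` the weighted adjacency matrix with :math:`(A_w)_{u,v} = w_{u,v}`):

.. math::

   h_v = \sum_{u \in \mathcal{N}(v)} w_{u,v} \, x_u,
   \qquad \text{or in matrix form} \qquad
   H = A_w^\top X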
With DGL, one can achieve this by:

- Saving the edge weights as an edge feature
- Multiplying the original message by the edge feature in the message function

Consider the message passing example with DGL below.
.. code::

    import dgl.function as fn

    # Suppose graph.ndata['ft'] stores the input node features
    graph.update_all(fn.copy_u('ft', 'm'), fn.sum('m', 'ft'))
One can modify it for edge weight support as follows.
.. code::

    import dgl.function as fn

    # Save the edge weights as an edge feature, which is a tensor of shape (E, *)
    # E is the number of edges
    graph.edata['w'] = eweight
    # Suppose graph.ndata['ft'] stores the input node features
    graph.update_all(fn.u_mul_e('ft', 'w', 'm'), fn.sum('m', 'ft'))
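As a sanity check on the adjacency-matrix view above, the sketch below compares this weighted ``update_all`` against a dense multiplication with the weighted adjacency matrix (the toy graph, features, and weights are made up for illustration):

.. code::

    import dgl
    import dgl.function as fn
    import torch

    src, dst = torch.tensor([0, 0, 1, 2]), torch.tensor([1, 2, 2, 0])
    eweight = torch.rand(4, 1)     # one scalar weight per edge
    feat = torch.randn(3, 5)       # node features

    # Dense form: out[v] = sum_u A[v, u] * feat[u], with A[dst, src] = weight
    A = torch.zeros(3, 3)
    A[dst, src] = eweight.squeeze(-1)
    out_dense = A @ feat

    # Message passing form with edge weights
    graph = dgl.graph((src, dst))
    graph.ndata['ft'] = feat
    graph.edata['w'] = eweight
    graph.update_all(fn.u_mul_e('ft', 'w', 'm'), fn.sum('m', 'ft'))
    assert torch.allclose(graph.ndata['ft'], out_dense, atol=1e-6)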
Using NN Modules with Edge Weights
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
One can add edge weight support to an NN module by modifying all message passing operations
in it. The following code snippet is an example of an NN module supporting edge weights.
.. code::

    import dgl.function as fn
    import torch.nn as nn

    class GNN(nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.linear = nn.Linear(in_feats, out_feats)

        def forward(self, g, feat, edge_weight=None):
            with g.local_scope():
                g.ndata['ft'] = self.linear(feat)
                if edge_weight is None:
                    msg_func = fn.copy_u('ft', 'm')
                else:
                    g.edata['w'] = edge_weight
                    msg_func = fn.u_mul_e('ft', 'w', 'm')
                g.update_all(msg_func, fn.sum('m', 'ft'))
                return g.ndata['ft']
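A quick usage sketch of this module follows; the input sizes and data are illustrative only.

.. code::

    import dgl
    import torch

    g = dgl.graph(([0, 0, 1, 2], [1, 2, 2, 0]))
    feat = torch.randn(3, 10)
    edge_weight = torch.rand(g.num_edges(), 1)

    model = GNN(10, 16)
    out_weighted = model(g, feat, edge_weight)  # weighted aggregation
    out_plain = model(g, feat)                  # falls back to fn.copy_u
    print(out_weighted.shape)                   # torch.Size([3, 16])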
A built-in DGL NN module supports edge weights if its forward function takes an
optional :attr:`edge_weight` argument.
One may need to normalize raw edge weights. In this regard, DGL provides
:func:`~dgl.nn.pytorch.conv.EdgeWeightNorm`.
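For instance, a possible pattern with :class:`~dgl.nn.pytorch.conv.GraphConv`, which accepts an optional ``edge_weight`` argument in its forward function (the toy graph and raw weights below are made up):

.. code::

    import dgl
    import torch
    from dgl.nn import EdgeWeightNorm, GraphConv

    g = dgl.graph(([0, 0, 1, 2], [1, 2, 2, 0]))
    feat = torch.randn(3, 10)
    raw_weight = torch.rand(g.num_edges()) + 0.1  # positive raw scalar weights, shape (E,)

    # Symmetrically normalize the raw weights, as in the GCN formulation
    norm = EdgeWeightNorm(norm='both')
    norm_weight = norm(g, raw_weight)

    # norm='none' on the module, since the weights are already normalized
    conv = GraphConv(10, 16, norm='none')
    out = conv(g, feat, edge_weight=norm_weight)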
@@ -97,6 +97,7 @@ The chapter has four sections, each for one type of graph learning tasks.
* :ref:`guide-training-edge-classification`
* :ref:`guide-training-link-prediction`
* :ref:`guide-training-graph-classification`
* :ref:`guide-training-eweight`
.. toctree::
:maxdepth: 1
@@ -107,3 +108,4 @@ The chapter has four sections, each for one type of graph learning tasks.
training-edge
training-link
training-graph
training-eweight
.. _guide_cn-message-passing-edge:

2.4 Apply Edge Weight In Message Passing
-----------------------------------------

:ref:`(English Version) <guide-message-passing-edge>`

A common practice in GNN modeling is to apply edge weights to the messages before message aggregation,
for example, in `GAT <https://arxiv.org/pdf/1710.10903.pdf>`__ and some `GCN variants <https://arxiv.org/abs/2004.00445>`__.
In DGL, the way to handle this is:

- Save the weights as an edge feature.
- Multiply the source node feature by the edge feature in the message function.

For example:

.. code::

    import dgl.function as fn

    # Suppose eweight is a tensor of shape (E, *), where E is the number of edges.
    graph.edata['a'] = eweight
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))

In the code above, ``eweight`` is used as the edge weight. The edge weight should usually be a scalar for each edge.
@@ -28,7 +28,6 @@
* :ref:`guide_cn-message-passing-api`
* :ref:`guide_cn-message-passing-efficient`
* :ref:`guide_cn-message-passing-part`
* :ref:`guide_cn-message-passing-edge`
* :ref:`guide_cn-message-passing-heterograph`
.. toctree::
@@ -39,5 +38,4 @@
message-api
message-efficient
message-part
message-edge
message-heterograph
.. _guide_cn-training-eweight:

5.5 Use of Edge Weights
----------------------------------

:ref:`(English Version) <guide-training-eweight>`

In a weighted graph, each edge is associated with a semantically meaningful scalar weight. For
example, the edge weights can be connectivity strengths or confidence scores. Naturally, one
may want to utilize edge weights in model development.

Message Passing with Edge Weights
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Most graph neural networks integrate graph structure information in forward computation solely
through the message passing mechanism. A message passing operation can be viewed as a function
whose input arguments are an adjacency matrix and additional input features. For an unweighted
graph, the entries of the adjacency matrix are either zero or one, where a one-valued entry
indicates an edge. For a weighted graph, the non-zero entries can take arbitrary scalar values.
This is equivalent to multiplying each message by its corresponding edge weight, as in
`GAT <https://arxiv.org/pdf/1710.10903.pdf>`__.

With DGL, this can be achieved by the following steps:

- Save the edge weights as an edge feature
- In the message function, multiply the original message of each edge by the saved edge feature

Consider the message passing example with DGL below:

.. code::

    import dgl.function as fn

    # Suppose graph.ndata['ft'] stores the input node features
    graph.update_all(fn.copy_u('ft', 'm'), fn.sum('m', 'ft'))

One can modify it for edge weight support as follows:

.. code::

    import dgl.function as fn

    # Save the edge weights as an edge feature, which is a tensor of shape (E, *)
    # E is the number of edges
    graph.edata['w'] = eweight
    # Suppose graph.ndata['ft'] stores the input node features
    graph.update_all(fn.u_mul_e('ft', 'w', 'm'), fn.sum('m', 'ft'))

Using NN Modules with Edge Weights
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

One can add edge weight support to an NN module by modifying all message passing operations
in it. The following code snippet gives an example.

.. code::

    import dgl.function as fn
    import torch.nn as nn

    class GNN(nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.linear = nn.Linear(in_feats, out_feats)

        def forward(self, g, feat, edge_weight=None):
            with g.local_scope():
                g.ndata['ft'] = self.linear(feat)
                if edge_weight is None:
                    msg_func = fn.copy_u('ft', 'm')
                else:
                    g.edata['w'] = edge_weight
                    msg_func = fn.u_mul_e('ft', 'w', 'm')
                g.update_all(msg_func, fn.sum('m', 'ft'))
                return g.ndata['ft']

DGL's built-in NN modules support edge weights if they take an optional :attr:`edge_weight`
argument in the forward function. One may need to normalize raw edge weights; for this purpose,
DGL provides :func:`~dgl.nn.pytorch.conv.EdgeWeightNorm`.
@@ -88,6 +88,7 @@
* :ref:`guide_cn-training-edge-classification`
* :ref:`guide_cn-training-link-prediction`
* :ref:`guide_cn-training-graph-classification`
* :ref:`guide_cn-training-eweight`
.. toctree::
:maxdepth: 1
......