Unverified Commit 16eba6e8 authored by Zhiteng Li, committed by GitHub

[NN] Add biased multi-head attention module (dense) (#4916)



* Add biased multi-head attention module (dense)

* fix lint issues

* refine according to dongyu's comments
Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent 65b34702
@@ -117,6 +117,7 @@ Utility Modules
~dgl.nn.pytorch.utils.LabelPropagation
~dgl.nn.pytorch.graph_transformer.DegreeEncoder
~dgl.nn.pytorch.utils.LaplacianPosEnc
~dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention
Network Embedding Modules
----------------------------------------
"""Torch modules for graph transformers."""
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
__all__ = ["DegreeEncoder", "BiasedMultiheadAttention"]
class DegreeEncoder(nn.Module):
r"""Degree Encoder, as introduced in
@@ -83,3 +85,137 @@ class DegreeEncoder(nn.Module):
        )
        return degree_embedding

class BiasedMultiheadAttention(nn.Module):
    r"""Dense Multi-Head Attention Module with Graph Attention Bias.

    Compute attention between nodes with an attention bias obtained from graph
    structures, as introduced in `Do Transformers Really Perform Bad for
    Graph Representation? <https://arxiv.org/pdf/2106.05234>`__

    .. math::

        \text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)

    :math:`Q` and :math:`K` are feature representations of nodes. :math:`d`
    is the corresponding :attr:`feat_size`. :math:`b` is the attention bias,
    which can be additive or multiplicative depending on the operator
    :math:`\circ`.

    Parameters
    ----------
    feat_size : int
        Feature size.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    bias : bool, optional
        If True, use bias in the linear projections. Default: True.
    attn_bias_type : str, optional
        The type of attention bias used for modifying attention. Selected from
        'add' or 'mul'. Default: 'add'.

        * 'add' is for additive attention bias.
        * 'mul' is for multiplicative attention bias.
    attn_drop : float, optional
        Dropout probability on attention weights. Default: 0.1.

    Examples
    --------
    >>> import torch as th
    >>> from dgl.nn import BiasedMultiheadAttention

    >>> ndata = th.rand(16, 100, 512)
    >>> bias = th.rand(16, 100, 100, 8)
    >>> net = BiasedMultiheadAttention(feat_size=512, num_heads=8)
    >>> out = net(ndata, bias)
    """

    def __init__(
        self, feat_size, num_heads, bias=True, attn_bias_type="add", attn_drop=0.1
    ):
        super().__init__()
        self.feat_size = feat_size
        self.num_heads = num_heads
        self.head_dim = feat_size // num_heads
        assert (
            self.head_dim * num_heads == feat_size
        ), "feat_size must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.attn_bias_type = attn_bias_type

        self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)

        self.dropout = nn.Dropout(p=attn_drop)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset parameters of the projection matrices, using the same
        settings as Graphormer.
        """
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(self, ndata, attn_bias=None, attn_mask=None):
        """Forward computation.

        Parameters
        ----------
        ndata : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`),
            where N is the maximum number of nodes.
        attn_bias : torch.Tensor, optional
            The attention bias used for attention modification. Shape:
            (batch_size, N, N, :attr:`num_heads`).
        attn_mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions, where invalid positions are indicated by non-zero
            values. Shape: (batch_size, N, N).

        Returns
        -------
        y : torch.Tensor
            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
        """
        q_h = self.q_proj(ndata).transpose(0, 1)
        k_h = self.k_proj(ndata).transpose(0, 1)
        v_h = self.v_proj(ndata).transpose(0, 1)
        bsz, N, _ = ndata.shape

        # Scale queries by 1/sqrt(head_dim), matching the docstring formula.
        q_h = (
            q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            * self.scaling
        )
        k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(1, 2, 0)
        v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        # (batch_size * num_heads, N, N) -> (batch_size, N, N, num_heads) so
        # that the per-head bias and the mask can be applied directly.
        attn_weights = (
            th.bmm(q_h, k_h)
            .transpose(0, 2)
            .reshape(N, N, bsz, self.num_heads)
            .transpose(0, 2)
        )

        if attn_bias is not None:
            if self.attn_bias_type == "add":
                attn_weights += attn_bias
            else:
                attn_weights *= attn_bias
        if attn_mask is not None:
            attn_weights[attn_mask.to(th.bool)] = float("-inf")

        # Back to (batch_size * num_heads, N, N); normalize over the keys.
        attn_weights = F.softmax(
            attn_weights.transpose(0, 2)
            .reshape(N, N, bsz * self.num_heads)
            .transpose(0, 2),
            dim=2,
        )
        attn_weights = self.dropout(attn_weights)

        attn = th.bmm(attn_weights, v_h).transpose(0, 1)
        attn = self.out_proj(
            attn.reshape(N, bsz, self.feat_size).transpose(0, 1)
        )

        return attn
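To make the intended use of attn_bias_type and attn_mask a bit more concrete, here is a minimal usage sketch that is not part of this commit: it builds a key padding mask from hypothetical per-graph node counts and feeds a multiplicative bias to the module above. The names num_nodes and graph_bias, and the choice to mask only padded key positions, are illustrative assumptions.

import torch as th
from dgl.nn import BiasedMultiheadAttention

# Hypothetical batch: 4 graphs padded to N = 10 nodes, 8 attention heads.
num_nodes = th.tensor([10, 7, 9, 4])    # real node count per graph (assumed)
bsz, N, feat_size, num_heads = 4, 10, 64, 8
ndata = th.rand(bsz, N, feat_size)

# Per the forward() docs, non-zero mask entries mark invalid positions.
# Here we mask the padded key positions for every query.
valid = th.arange(N)[None, :] < num_nodes[:, None]    # (bsz, N), True for real nodes
attn_mask = (~valid)[:, None, :].expand(bsz, N, N)    # (bsz, N, N)

# A per-head bias, e.g. derived from graph structure; random for illustration.
graph_bias = th.rand(bsz, N, N, num_heads)

net = BiasedMultiheadAttention(feat_size, num_heads, attn_bias_type="mul")
out = net(ndata, attn_bias=graph_bias, attn_mask=attn_mask)
assert out.shape == (bsz, N, feat_size)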
@@ -1786,3 +1786,18 @@ def test_LaplacianPosEnc(num_layer, k, lpe_dim, n_head, batch_norm, num_post_lay
    model = nn.LaplacianPosEnc("DeepSet", num_layer, k, lpe_dim,
                               batch_norm=batch_norm, num_post_layer=num_post_layer).to(ctx)
    assert model(EigVals, EigVecs).shape == (num_nodes, lpe_dim)

@pytest.mark.parametrize('feat_size', [128, 512])
@pytest.mark.parametrize('num_heads', [8, 16])
@pytest.mark.parametrize('bias', [True, False])
@pytest.mark.parametrize('attn_bias_type', ['add', 'mul'])
@pytest.mark.parametrize('attn_drop', [0.1, 0.5])
def test_BiasedMultiheadAttention(feat_size, num_heads, bias, attn_bias_type, attn_drop):
    ndata = th.rand(16, 100, feat_size)
    attn_bias = th.rand(16, 100, 100, num_heads)
    attn_mask = th.rand(16, 100, 100) < 0.5
    net = nn.BiasedMultiheadAttention(feat_size, num_heads, bias, attn_bias_type, attn_drop)
    out = net(ndata, attn_bias, attn_mask)
    assert out.shape == (16, 100, feat_size)
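Beyond the shape check above, one could also compare the module against a direct einsum reimplementation of softmax(QK^T / sqrt(d) + b)V. The sketch below is not part of this commit: the helper name, test name, and tolerance are assumptions, it covers only the 'add' bias type, dropout is disabled via eval() so both paths are deterministic, and it assumes the forward pass scales queries by 1/sqrt(head_dim) as in the class docstring.

def _reference_biased_attention(module, ndata, attn_bias):
    # Recompute 'add'-type biased attention head by head with einsum.
    bsz, N, _ = ndata.shape
    H, D = module.num_heads, module.head_dim
    q = module.q_proj(ndata).reshape(bsz, N, H, D)
    k = module.k_proj(ndata).reshape(bsz, N, H, D)
    v = module.v_proj(ndata).reshape(bsz, N, H, D)
    scores = th.einsum("bqhd,bkhd->bqkh", q, k) * module.scaling + attn_bias
    weights = th.softmax(scores, dim=2)    # normalize over keys
    out = th.einsum("bqkh,bkhd->bqhd", weights, v).reshape(bsz, N, H * D)
    return module.out_proj(out)


def test_BiasedMultiheadAttention_reference():
    ndata = th.rand(2, 10, 32)
    attn_bias = th.rand(2, 10, 10, 8)
    net = nn.BiasedMultiheadAttention(32, 8, attn_bias_type='add').eval()
    assert th.allclose(net(ndata, attn_bias),
                       _reference_biased_attention(net, ndata, attn_bias),
                       atol=1e-5)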