"docs/vscode:/vscode.git/clone" did not exist on "0ca82156e1799672a9522e539c65f379bdb91a2a"
Unverified Commit bfd411d0 authored by Tingyu Wang, committed by GitHub

[Model] Add `dgl.nn.CuGraphGATConv` model (#5168)



* add CuGraphGATConv model

* lintrunner

* update model to reflect changes in make_mfg_csr(), move max_in_degree to forward()

* simplify pytest markers

* fall back to FG option for large fanout

* update error msg

* add feat_drop and activation options

* add residual option

* Update python/dgl/nn/pytorch/conv/cugraph_gatconv.py
Co-authored-by: Mufei Li <mufeili1996@gmail.com>

* Update python/dgl/nn/pytorch/conv/cugraph_gatconv.py
Co-authored-by: Mufei Li <mufeili1996@gmail.com>

* reset res_fc

---------
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent f00cd6ef
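
A minimal usage sketch of the new layer with the options added in this commit (`feat_drop`, `residual`, `activation`). The toy graph, feature sizes, and head count are illustrative assumptions, not values from the PR; running it requires a CUDA device and the `pylibcugraphops` package, as noted in the module docstring.

```python
import dgl
import torch
import torch.nn.functional as F
from dgl.nn import CuGraphGATConv

device = "cuda"  # cugraph-ops kernels run on the GPU

# Toy graph and features (illustrative only).
g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0]))).to(device)
feat = torch.randn(3, 16, device=device)

# Options introduced in this commit: feature dropout, a residual connection,
# and an activation applied to the updated node features.
conv = CuGraphGATConv(
    16, 8, num_heads=4, feat_drop=0.1, residual=True, activation=F.elu
).to(device)

out = conv(g, feat)  # shape: (3, 4, 8) = (num_nodes, num_heads, out_feats)
```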
@@ -50,6 +50,7 @@ CuGraph Conv Layers
     :template: classtemplate.rst
 
     ~dgl.nn.pytorch.conv.CuGraphRelGraphConv
+    ~dgl.nn.pytorch.conv.CuGraphGATConv
     ~dgl.nn.pytorch.conv.CuGraphSAGEConv
 
 Dense Conv Layers

@@ -6,6 +6,7 @@ from .appnpconv import APPNPConv
 from .atomicconv import AtomicConv
 from .cfconv import CFConv
 from .chebconv import ChebConv
+from .cugraph_gatconv import CuGraphGATConv
 from .cugraph_relgraphconv import CuGraphRelGraphConv
 from .cugraph_sageconv import CuGraphSAGEConv
 from .densechebconv import DenseChebConv

@@ -67,6 +68,7 @@ __all__ = [
     "EGNNConv",
     "PNAConv",
     "DGNConv",
+    "CuGraphGATConv",
     "CuGraphRelGraphConv",
     "CuGraphSAGEConv",
 ]

new file: python/dgl/nn/pytorch/conv/cugraph_gatconv.py
"""Torch Module for graph attention network layer using the aggregation
primitives in cugraph-ops"""
# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments
import torch
from torch import nn
try:
from pylibcugraphops import make_fg_csr, make_mfg_csr
from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg
except ImportError:
has_pylibcugraphops = False
else:
has_pylibcugraphops = True
class CuGraphGATConv(nn.Module):
r"""Graph attention layer from `Graph Attention Networks
<https://arxiv.org/pdf/1710.10903.pdf>`__, with the sparse aggregation
accelerated by cugraph-ops.
    See :class:`dgl.nn.pytorch.conv.GATConv` for the mathematical model.

    This module depends on the :code:`pylibcugraphops` package, which can be
    installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.

    .. note::
        This is an **experimental** feature.

Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
num_heads : int
Number of heads in Multi-Head Attention.
feat_drop : float, optional
Dropout rate on feature. Defaults: ``0``.
negative_slope : float, optional
LeakyReLU angle of negative slope. Defaults: ``0.2``.
residual : bool, optional
If True, use residual connection. Defaults: ``False``.
activation : callable activation function/layer or None, optional.
If not None, applies an activation function to the updated node features.
Default: ``None``.
bias : bool, optional
If True, learns a bias term. Defaults: ``True``.
Examples
--------
>>> import dgl
>>> import torch
>>> from dgl.nn import CuGraphGATConv
>>> device = 'cuda'
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)
>>> g = dgl.add_self_loop(g)
>>> feat = torch.ones(6, 10).to(device)
>>> conv = CuGraphGATConv(10, 2, num_heads=3).to(device)
>>> res = conv(g, feat)
>>> res
tensor([[[ 0.2340, 1.9226],
[ 1.6477, -1.9986],
[ 1.1138, -1.9302]],
[[ 0.2340, 1.9226],
[ 1.6477, -1.9986],
[ 1.1138, -1.9302]],
[[ 0.2340, 1.9226],
[ 1.6477, -1.9986],
[ 1.1138, -1.9302]],
[[ 0.2340, 1.9226],
[ 1.6477, -1.9986],
[ 1.1138, -1.9302]],
[[ 0.2340, 1.9226],
[ 1.6477, -1.9986],
[ 1.1138, -1.9302]],
[[ 0.2340, 1.9226],
[ 1.6477, -1.9986],
[ 1.1138, -1.9302]]], device='cuda:0', grad_fn=<ViewBackward0>)
"""
MAX_IN_DEGREE_MFG = 500
def __init__(
self,
in_feats,
out_feats,
num_heads,
feat_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
bias=True,
):
if has_pylibcugraphops is False:
raise ModuleNotFoundError(
f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. "
f"Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`."
)
super().__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.num_heads = num_heads
self.feat_drop = nn.Dropout(feat_drop)
self.negative_slope = negative_slope
self.activation = activation
self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)
self.attn_weights = nn.Parameter(
torch.Tensor(2 * num_heads * out_feats)
)
if bias:
self.bias = nn.Parameter(torch.Tensor(num_heads * out_feats))
else:
self.register_buffer("bias", None)
if residual:
if in_feats == out_feats * num_heads:
self.res_fc = nn.Identity()
else:
self.res_fc = nn.Linear(
in_feats, out_feats * num_heads, bias=False
)
else:
self.register_buffer("res_fc", None)
self.reset_parameters()
def reset_parameters(self):
r"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc.weight, gain=gain)
nn.init.xavier_normal_(
self.attn_weights.view(2, self.num_heads, self.out_feats), gain=gain
)
if self.bias is not None:
nn.init.zeros_(self.bias)
if isinstance(self.res_fc, nn.Linear):
self.res_fc.reset_parameters()
def forward(self, g, feat, max_in_degree=None):
r"""Forward computation.
Parameters
----------
g : DGLGraph
The graph.
feat : torch.Tensor
Input features of shape :math:`(N, D_{in})`.
        max_in_degree : int, optional
            Maximum in-degree of destination nodes. It is only effective when
            :attr:`g` is a :class:`DGLBlock`, i.e., a bipartite graph. When
            :attr:`g` is generated from a neighbor sampler, the value should be
            set to the corresponding :attr:`fanout`. If not given,
            :attr:`max_in_degree` will be calculated on-the-fly.
Returns
-------
torch.Tensor
            The output feature of shape :math:`(N, H, D_{out})`, where
            :math:`H` is the number of heads and :math:`D_{out}` is the size
            of the output feature.
"""
offsets, indices, _ = g.adj_sparse("csc")
if g.is_block:
if max_in_degree is None:
max_in_degree = g.in_degrees().max().item()
if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = make_mfg_csr(
g.dstnodes(),
offsets,
indices,
max_in_degree,
g.num_src_nodes(),
)
else:
offsets_fg = torch.empty(
g.num_src_nodes() + 1,
dtype=offsets.dtype,
device=offsets.device,
)
offsets_fg[: offsets.numel()] = offsets
offsets_fg[offsets.numel() :] = offsets[-1]
_graph = make_fg_csr(offsets_fg, indices)
else:
_graph = make_fg_csr(offsets, indices)
feat = self.feat_drop(feat)
feat_transformed = self.fc(feat)
out = GATConvAgg(
feat_transformed,
self.attn_weights,
_graph,
self.num_heads,
"LeakyReLU",
self.negative_slope,
add_own_node=False,
concat_heads=True,
)[: g.num_dst_nodes()].view(-1, self.num_heads, self.out_feats)
feat_dst = feat[: g.num_dst_nodes()]
if self.res_fc is not None:
out = out + self.res_fc(feat_dst).view(
-1, self.num_heads, self.out_feats
)
if self.bias is not None:
out = out + self.bias.view(-1, self.num_heads, self.out_feats)
if self.activation is not None:
out = self.activation(out)
return out
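
The `max_in_degree` argument of `forward()` targets sampled minibatches: as long as the value stays below `MAX_IN_DEGREE_MFG` (500), the layer aggregates over an MFG CSR built by `make_mfg_csr()`; larger values fall back to the full-graph CSR path. Below is a minimal minibatch sketch; the random graph, the `"feat"` field, the fanout, and the training ids are illustrative assumptions, not part of this commit.

```python
import dgl
import torch
from dgl.nn import CuGraphGATConv

device = "cuda"
fanout = 10  # stays below MAX_IN_DEGREE_MFG, so the MFG path is taken

# Assumed toy data: a random graph with self-loops and a "feat" node field.
g = dgl.add_self_loop(dgl.rand_graph(1000, 5000))
g.ndata["feat"] = torch.randn(g.num_nodes(), 16)
train_nids = torch.arange(100)

sampler = dgl.dataloading.NeighborSampler([fanout])
dataloader = dgl.dataloading.DataLoader(
    g, train_nids, sampler, batch_size=32, shuffle=True, device=device
)

conv = CuGraphGATConv(16, 8, num_heads=4).to(device)

for input_nodes, output_nodes, blocks in dataloader:
    x = blocks[0].srcdata["feat"]
    # Passing the sampler fanout as max_in_degree avoids the on-the-fly
    # in-degree computation inside forward().
    out = conv(blocks[0], x, max_in_degree=fanout)
    # out: (num_dst_nodes, num_heads, out_feats)
```

The unit test that follows checks numerical parity against `dgl.nn.GATConv` once the attention weights and the linear projection are copied over.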
# pylint: disable=too-many-arguments, too-many-locals
from collections import OrderedDict
from itertools import product
import dgl
import pytest
import torch
from dgl.nn import CuGraphGATConv, GATConv
options = OrderedDict(
{
"idtype_int": [False, True],
"max_in_degree": [None, 8],
"num_heads": [1, 3],
"to_block": [False, True],
}
)
def generate_graph():
u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])
v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])
g = dgl.graph((u, v))
return g
@pytest.mark.skip()
@pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
def test_gatconv_equality(idtype_int, max_in_degree, num_heads, to_block):
device = "cuda:0"
in_feat, out_feat = 10, 2
args = (in_feat, out_feat, num_heads)
kwargs = {"bias": False}
g = generate_graph().to(device)
if idtype_int:
g = g.int()
if to_block:
g = dgl.to_block(g)
feat = torch.rand(g.num_src_nodes(), in_feat).to(device)
torch.manual_seed(0)
conv1 = GATConv(*args, **kwargs, allow_zero_in_degree=True).to(device)
out1 = conv1(g, feat)
torch.manual_seed(0)
conv2 = CuGraphGATConv(*args, **kwargs).to(device)
dim = num_heads * out_feat
with torch.no_grad():
conv2.attn_weights.data[:dim] = conv1.attn_l.data.flatten()
conv2.attn_weights.data[dim:] = conv1.attn_r.data.flatten()
conv2.fc.weight.data[:] = conv1.fc.weight.data
out2 = conv2(g, feat, max_in_degree=max_in_degree)
assert torch.allclose(out1, out2, atol=1e-6)
grad_out1 = torch.rand_like(out1)
grad_out2 = grad_out1.clone().detach()
out1.backward(grad_out1)
out2.backward(grad_out2)
assert torch.allclose(conv1.fc.weight.grad, conv2.fc.weight.grad, atol=1e-6)
assert torch.allclose(
torch.cat((conv1.attn_l.grad, conv1.attn_r.grad), dim=0),
conv2.attn_weights.grad.view(2, num_heads, out_feat),
atol=1e-6,
)