Unverified Commit bcf9923b authored by Tingyu Wang, committed by GitHub

[Model] Add `dgl.nn.CuGraphSAGEConv` model (#5137)



* add CuGraphSAGEConv model

* fix lint issues

* update model to reflect changes in make_mfg_csr(), move max_in_degree to forward()

* lintrunner

* allow reset_parameters()

* remove norm option, simplify test

* allow full graph fallback option, add example

* address comments

* address reviews

---------
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent d8ca6317
@@ -14,7 +14,6 @@ Conv Layers
~dgl.nn.pytorch.conv.GraphConv
~dgl.nn.pytorch.conv.EdgeWeightNorm
~dgl.nn.pytorch.conv.RelGraphConv
~dgl.nn.pytorch.conv.CuGraphRelGraphConv
~dgl.nn.pytorch.conv.TAGConv
~dgl.nn.pytorch.conv.GATConv
~dgl.nn.pytorch.conv.GATv2Conv
@@ -42,6 +41,17 @@ Conv Layers
~dgl.nn.pytorch.conv.PNAConv
~dgl.nn.pytorch.conv.DGNConv
CuGraph Conv Layers
----------------------------------------
.. autosummary::
:toctree: ../../generated/
:nosignatures:
:template: classtemplate.rst
~dgl.nn.pytorch.conv.CuGraphRelGraphConv
~dgl.nn.pytorch.conv.CuGraphSAGEConv
Dense Conv Layers
----------------------------------------
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
DataLoader,
MultiLayerFullNeighborSampler,
NeighborSampler,
)
from dgl.nn import CuGraphSAGEConv
from ogb.nodeproppred import DglNodePropPredDataset
class SAGE(nn.Module):
def __init__(self, in_size, hid_size, out_size):
super().__init__()
self.layers = nn.ModuleList()
# three-layer GraphSAGE-mean
self.layers.append(CuGraphSAGEConv(in_size, hid_size, "mean"))
self.layers.append(CuGraphSAGEConv(hid_size, hid_size, "mean"))
self.layers.append(CuGraphSAGEConv(hid_size, out_size, "mean"))
self.dropout = nn.Dropout(0.5)
self.hid_size = hid_size
self.out_size = out_size
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(block, h)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
return h
def inference(self, g, device, batch_size):
"""Conduct layer-wise inference to get all the node embeddings."""
feat = g.ndata["feat"]
sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
dataloader = DataLoader(
g,
torch.arange(g.num_nodes()).to(g.device),
sampler,
device=device,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=0,
)
buffer_device = torch.device("cpu")
pin_memory = buffer_device != device
for l, layer in enumerate(self.layers):
y = torch.empty(
g.num_nodes(),
self.hid_size if l != len(self.layers) - 1 else self.out_size,
device=buffer_device,
pin_memory=pin_memory,
)
feat = feat.to(device)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = feat[input_nodes]
h = layer(blocks[0], x) # len(blocks) = 1
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
# by design, our output nodes are contiguous
y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
feat = y
return y
def evaluate(model, graph, dataloader):
model.eval()
ys = []
y_hats = []
for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
with torch.no_grad():
x = blocks[0].srcdata["feat"]
ys.append(blocks[-1].dstdata["label"])
y_hats.append(model(blocks, x))
num_classes = y_hats[0].shape[1]
return MF.accuracy(
torch.cat(y_hats),
torch.cat(ys),
task="multiclass",
num_classes=num_classes,
)
def layerwise_infer(device, graph, nid, model, batch_size):
model.eval()
with torch.no_grad():
pred = model.inference(
graph, device, batch_size
) # pred in buffer_device
pred = pred[nid]
label = graph.ndata["label"][nid].to(pred.device)
num_classes = pred.shape[1]
return MF.accuracy(
pred, label, task="multiclass", num_classes=num_classes
)
def train(args, device, g, dataset, model):
# create sampler & dataloader
train_idx = dataset.train_idx.to(device)
val_idx = dataset.val_idx.to(device)
sampler = NeighborSampler(
[10, 10, 10], # fanout for [layer-0, layer-1, layer-2]
prefetch_node_feats=["feat"],
prefetch_labels=["label"],
)
use_uva = args.mode == "mixed"
train_dataloader = DataLoader(
g,
train_idx,
sampler,
device=device,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=0,
use_uva=use_uva,
)
val_dataloader = DataLoader(
g,
val_idx,
sampler,
device=device,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=0,
use_uva=use_uva,
)
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
for epoch in range(10):
model.train()
total_loss = 0
for it, (input_nodes, output_nodes, blocks) in enumerate(
train_dataloader
):
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]
y_hat = model(blocks, x)
loss = F.cross_entropy(y_hat, y)
opt.zero_grad()
loss.backward()
opt.step()
total_loss += loss.item()
acc = evaluate(model, g, val_dataloader)
print(
"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
epoch, total_loss / (it + 1), acc.item()
)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--mode",
default="mixed",
choices=["mixed", "puregpu"],
help="Training mode. 'mixed' for CPU-GPU mixed training, "
"'puregpu' for pure-GPU training.",
)
args = parser.parse_args()
if not torch.cuda.is_available():
args.mode = "cpu"
print(f"Training in {args.mode} mode.")
# load and preprocess dataset
print("Loading data")
dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
g = dataset[0]
g = g.to("cuda" if args.mode == "puregpu" else "cpu")
device = torch.device("cpu" if args.mode == "cpu" else "cuda")
# create GraphSAGE model
in_size = g.ndata["feat"].shape[1]
out_size = dataset.num_classes
model = SAGE(in_size, 256, out_size).to(device)
# model training
print("Training...")
train(args, device, g, dataset, model)
# test the model
print("Testing...")
acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)
print("Test Accuracy {:.4f}".format(acc.item()))
@@ -7,6 +7,7 @@ from .atomicconv import AtomicConv
from .cfconv import CFConv
from .chebconv import ChebConv
from .cugraph_relgraphconv import CuGraphRelGraphConv
from .cugraph_sageconv import CuGraphSAGEConv
from .densechebconv import DenseChebConv
from .densegraphconv import DenseGraphConv
from .densesageconv import DenseSAGEConv
@@ -67,4 +68,5 @@ __all__ = [
"PNAConv",
"DGNConv",
"CuGraphRelGraphConv",
"CuGraphSAGEConv",
]
"""Torch Module for GraphSAGE layer using the aggregation primitives in
cugraph-ops"""
# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments
import torch
from torch import nn
try:
from pylibcugraphops import make_fg_csr, make_mfg_csr
from pylibcugraphops.torch.autograd import agg_concat_n2n as SAGEConvAgg
except ImportError:
has_pylibcugraphops = False
else:
has_pylibcugraphops = True
class CuGraphSAGEConv(nn.Module):
r"""An accelerated GraphSAGE layer from `Inductive Representation Learning
on Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__ that leverages the
highly-optimized aggregation primitives in cugraph-ops:
.. math::
h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
\left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)
h_{i}^{(l+1)} &= W \cdot \mathrm{concat}
(h_{i}^{l}, h_{\mathcal{N}(i)}^{(l+1)})
This module depends on the :code:`pylibcugraphops` package, which can be
installed via :code:`conda install -c nvidia 'pylibcugraphops>=23.02'`.
.. note::
This is an **experimental** feature.
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
aggregator_type : str
Aggregator type to use (``mean``, ``sum``, ``min``, ``max``).
feat_drop : float
Dropout rate on features, default: ``0``.
bias : bool
If True, adds a learnable bias to the output. Default: ``True``.
Examples
--------
>>> import dgl
>>> import torch
>>> from dgl.nn import CuGraphSAGEConv
>>> device = 'cuda'
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)
>>> g = dgl.add_self_loop(g)
>>> feat = torch.ones(6, 10).to(device)
>>> conv = CuGraphSAGEConv(10, 2, 'mean').to(device)
>>> res = conv(g, feat)
>>> res
tensor([[-1.1690, 0.1952],
[-1.1690, 0.1952],
[-1.1690, 0.1952],
[-1.1690, 0.1952],
[-1.1690, 0.1952],
[-1.1690, 0.1952]], device='cuda:0', grad_fn=<AddmmBackward0>)
"""
MAX_IN_DEGREE_MFG = 500
def __init__(
self,
in_feats,
out_feats,
aggregator_type="mean",
feat_drop=0.0,
bias=True,
):
if not has_pylibcugraphops:
raise ModuleNotFoundError(
f"{self.__class__.__name__} requires pylibcugraphops >= 23.02. "
f"Install via `conda install -c nvidia 'pylibcugraphops>=23.02'`."
)
valid_aggr_types = {"max", "min", "mean", "sum"}
if aggregator_type not in valid_aggr_types:
raise ValueError(
f"Invalid aggregator_type. Must be one of {valid_aggr_types}. "
f"But got '{aggregator_type}' instead."
)
super().__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.aggr = aggregator_type
self.feat_drop = nn.Dropout(feat_drop)
self.linear = nn.Linear(2 * in_feats, out_feats, bias=bias)
def reset_parameters(self):
r"""Reinitialize learnable parameters."""
self.linear.reset_parameters()
def forward(self, g, feat, max_in_degree=None):
r"""Forward computation.
Parameters
----------
g : DGLGraph
The graph.
feat : torch.Tensor
Node features. Shape: :math:`(N, D_{in})`.
max_in_degree : int
Maximum in-degree of destination nodes. It is only effective when
:attr:`g` is a :class:`DGLBlock`, i.e., a bipartite graph. When
:attr:`g` is generated from a neighbor sampler, the value should be
set to the corresponding :attr:`fanout`. If not given,
:attr:`max_in_degree` will be calculated on-the-fly.
Returns
-------
torch.Tensor
Output node features. Shape: :math:`(N, D_{out})`.
"""
offsets, indices, _ = g.adj_sparse("csc")
if g.is_block:
if max_in_degree is None:
max_in_degree = g.in_degrees().max().item()
if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = make_mfg_csr(
g.dstnodes(),
offsets,
indices,
max_in_degree,
g.num_src_nodes(),
)
else:
offsets_fg = torch.empty(
g.num_src_nodes() + 1,
dtype=offsets.dtype,
device=offsets.device,
)
offsets_fg[: offsets.numel()] = offsets
offsets_fg[offsets.numel() :] = offsets[-1]
_graph = make_fg_csr(offsets_fg, indices)
else:
_graph = make_fg_csr(offsets, indices)
feat = self.feat_drop(feat)
h = SAGEConvAgg(feat, _graph, self.aggr)[: g.num_dst_nodes()]
h = self.linear(h)
return h
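
The `max_in_degree` argument documented in `forward()` above is intended to match the sampler fanout when the input graph is a block (MFG) produced by a neighbor sampler. Below is a minimal sketch of that usage; it assumes a CUDA device with `pylibcugraphops` installed, and the toy graph, feature sizes, and fanout are illustrative only, not part of this change.

import dgl
import torch
from dgl.dataloading import DataLoader, NeighborSampler
from dgl.nn import CuGraphSAGEConv

device = "cuda"
# Toy graph with self-loops so every destination node has at least one neighbor.
g = dgl.add_self_loop(dgl.rand_graph(1000, 5000)).to(device)
g.ndata["feat"] = torch.randn(g.num_nodes(), 16, device=device)

# Single-layer neighbor sampling with a fanout of 10.
sampler = NeighborSampler([10])
loader = DataLoader(
    g,
    torch.arange(g.num_nodes(), device=device),
    sampler,
    device=device,
    batch_size=128,
    shuffle=True,
)

conv = CuGraphSAGEConv(16, 8, "mean").to(device)
for input_nodes, output_nodes, blocks in loader:
    x = g.ndata["feat"][input_nodes]
    # Setting max_in_degree to the fanout lets the layer build the cugraph-ops
    # MFG representation directly instead of computing in-degrees on the fly.
    h = conv(blocks[0], x, max_in_degree=10)
    break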
# pylint: disable=too-many-arguments, too-many-locals
from collections import OrderedDict
from itertools import product
import dgl
import pytest
import torch
from dgl.nn import CuGraphSAGEConv, SAGEConv
options = OrderedDict(
{
"idtype_int": [False, True],
"max_in_degree": [None, 8],
"to_block": [False, True],
}
)
def generate_graph():
u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])
v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])
g = dgl.graph((u, v))
return g
@pytest.mark.skip()
@pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
def test_SAGEConv_equality(idtype_int, max_in_degree, to_block):
device = "cuda:0"
in_feat, out_feat = 5, 2
kwargs = {"aggregator_type": "mean"}
g = generate_graph().to(device)
if idtype_int:
g = g.int()
if to_block:
g = dgl.to_block(g)
feat = torch.rand(g.num_src_nodes(), in_feat).to(device)
torch.manual_seed(0)
conv1 = SAGEConv(in_feat, out_feat, **kwargs).to(device)
torch.manual_seed(0)
conv2 = CuGraphSAGEConv(in_feat, out_feat, **kwargs).to(device)
with torch.no_grad():
conv2.linear.weight.data[:, :in_feat] = conv1.fc_neigh.weight.data
conv2.linear.weight.data[:, in_feat:] = conv1.fc_self.weight.data
conv2.linear.bias.data[:] = conv1.fc_self.bias.data
out1 = conv1(g, feat)
out2 = conv2(g, feat, max_in_degree=max_in_degree)
assert torch.allclose(out1, out2, atol=1e-06)
grad_out = torch.rand_like(out1)
out1.backward(grad_out)
out2.backward(grad_out)
assert torch.allclose(
conv1.fc_neigh.weight.grad,
conv2.linear.weight.grad[:, :in_feat],
atol=1e-6,
)
assert torch.allclose(
conv1.fc_self.weight.grad,
conv2.linear.weight.grad[:, in_feat:],
atol=1e-6,
)
assert torch.allclose(
conv1.fc_self.bias.grad, conv2.linear.bias.grad, atol=1e-6
)