"src/graph/vscode:/vscode.git/clone" did not exist on "7b3a7b14381acf7d5d8213e3e36a94fdf69c827b"
Unverified commit ba21295c, authored by Tingyu Wang, committed by GitHub

[Model] Update `CuGraphRelGraphConv` to use new bindings from `pylibcugraphops` (#4965)



* update agg function with new bindings

* handle optional import in __init__

* raise error in RelGraphConvAgg when pylibcugraphops not imported

* Update tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py
  Co-authored-by: Mufei Li <mufeili1996@gmail.com>

* Update tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py
  Co-authored-by: Mufei Li <mufeili1996@gmail.com>

* use keyword args for readability

* add missing docstring to pass CI

* catch ImportError rather than ModuleNotFoundError
  Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent fb223d47
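In short, the commit swaps pylibcugraphops' per-index-type legacy bindings for index-type-agnostic ones. The rename mapping below is read directly off the diff that follows:

# Binding rename summary (old legacy API -> new API), as reflected in the diff:
#
# pylibcugraphops.legacy.structure.graph_types.message_flow_graph_hg_csr_int32
# pylibcugraphops.legacy.structure.graph_types.message_flow_graph_hg_csr_int64
#     -> pylibcugraphops.make_mfg_csr_hg  (one binding for both index types)
#
# pylibcugraphops.legacy.aggregators.node_level.agg_hg_basis_post_fwd_int32/_int64
#     -> pylibcugraphops.operators.agg_hg_basis_mfg_n2n_post_fwd
#
# pylibcugraphops.legacy.aggregators.node_level.agg_hg_basis_post_bwd_int32/_int64
#     -> pylibcugraphops.operators.agg_hg_basis_mfg_n2n_post_bwd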
python/dgl/nn/pytorch/conv/__init__.py

@@ -6,6 +6,7 @@ from .appnpconv import APPNPConv
 from .atomicconv import AtomicConv
 from .cfconv import CFConv
 from .chebconv import ChebConv
+from .cugraph_relgraphconv import CuGraphRelGraphConv
 from .densechebconv import DenseChebConv
 from .densegraphconv import DenseGraphConv
 from .densesageconv import DenseSAGEConv
@@ -65,11 +66,5 @@ __all__ = [
     "EGNNConv",
     "PNAConv",
     "DGNConv",
+    "CuGraphRelGraphConv",
 ]
-
-try:
-    from .cugraph_relgraphconv import CuGraphRelGraphConv
-except ImportError:
-    pass
-else:
-    __all__.append('CuGraphRelGraphConv')
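The effect of this hunk: `CuGraphRelGraphConv` is now imported and exported unconditionally, so a missing `pylibcugraphops` no longer changes what `dgl.nn` exposes; the failure is deferred to layer construction. A hypothetical session illustrating the intended behavior (the constructor arguments are illustrative only, not taken from the diff):

# On a machine without pylibcugraphops:
from dgl.nn import CuGraphRelGraphConv  # always succeeds after this change

conv = CuGraphRelGraphConv(10, 2, 3)    # raises ModuleNotFoundError here,
                                        # via the guard added to __init__ below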
python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py

@@ -7,21 +7,27 @@ import torch as th
 from torch import nn
 
 try:
-    from pylibcugraphops.legacy.aggregators.node_level import (
-        agg_hg_basis_post_bwd_int32,
-        agg_hg_basis_post_bwd_int64,
-        agg_hg_basis_post_fwd_int32,
-        agg_hg_basis_post_fwd_int64,
-    )
-    from pylibcugraphops.legacy.structure.graph_types import (
-        message_flow_graph_hg_csr_int32,
-        message_flow_graph_hg_csr_int64,
-    )
-except ModuleNotFoundError:
-    raise ModuleNotFoundError(
-        "dgl.nn.CuGraphRelGraphConv requires pylibcugraphops to be installed."
-    )
+    from pylibcugraphops import make_mfg_csr_hg
+    from pylibcugraphops.operators import (
+        agg_hg_basis_mfg_n2n_post_bwd as agg_bwd,
+    )
+    from pylibcugraphops.operators import (
+        agg_hg_basis_mfg_n2n_post_fwd as agg_fwd,
+    )
+except ImportError:
+    has_pylibcugraphops = False
+
+    def make_mfg_csr_hg(*args):
+        r"""A dummy function to help raise error in RelGraphConvAgg when
+        pylibcugraphops is not found."""
+        raise NotImplementedError(
+            "RelGraphConvAgg requires pylibcugraphops to be installed."
+        )
+else:
+    has_pylibcugraphops = True
 
 
 class RelGraphConvAgg(th.autograd.Function):
     r"""Custom autograd function for R-GCN aggregation layer that uses the
@@ -57,46 +63,31 @@ class RelGraphConvAgg(th.autograd.Function):
             num_rels * in_feat) when ``coeff=None``; Shape: (num_dst_nodes,
             num_bases * in_feat) otherwise.
         """
-        if g.idtype == th.int32:
-            mfg_csr_func = message_flow_graph_hg_csr_int32
-            agg_fwd_func = agg_hg_basis_post_fwd_int32
-        elif g.idtype == th.int64:
-            mfg_csr_func = message_flow_graph_hg_csr_int64
-            agg_fwd_func = agg_hg_basis_post_fwd_int64
-        else:
-            raise TypeError(
-                f"Supported ID type: torch.int32 or torch.int64, but got "
-                f"{g.idtype}."
-            )
-        ctx.idtype = g.idtype
-        _in_feat = feat.shape[-1]
+        in_feat = feat.shape[-1]
         indptr, indices, edge_ids = g.adj_sparse("csc")
         # Edge_ids is in a mixed order, need to permutate incoming etypes.
-        ctx.edge_types_int32 = edge_types[edge_ids.long()].int()
-        # Node_types are not being used in agg_fwd_func.
-        _num_node_types = 0
-        _out_node_types = _in_node_types = None
+        ctx.edge_types_perm = edge_types[edge_ids.long()].int()
 
-        mfg = mfg_csr_func(
-            max_in_degree,
+        mfg = make_mfg_csr_hg(
             g.dstnodes(),
             g.srcnodes(),
             indptr,
             indices,
-            _num_node_types,
-            num_rels,
-            _out_node_types,
-            _in_node_types,
-            ctx.edge_types_int32,
+            max_in_degree,
+            n_node_types=0,
+            n_edge_types=num_rels,
+            out_node_types=None,
+            in_node_types=None,
+            edge_types=ctx.edge_types_perm,
         )
         ctx.mfg = mfg
 
         if coeff is None:
-            leading_dimension = num_rels * _in_feat
+            leading_dimension = num_rels * in_feat
         else:
-            _num_bases = coeff.shape[-1]
-            leading_dimension = _num_bases * _in_feat
+            num_bases = coeff.shape[-1]
+            leading_dimension = num_bases * in_feat
 
         agg_output = th.empty(
             g.num_dst_nodes(),
@@ -104,15 +95,11 @@ class RelGraphConvAgg(th.autograd.Function):
             dtype=th.float32,
             device=feat.device,
         )
         if coeff is None:
-            agg_fwd_func(agg_output, feat.detach(), mfg)
+            agg_fwd(agg_output, feat.detach(), None, mfg)
         else:
-            agg_fwd_func(
-                agg_output,
-                feat.detach(),
-                mfg,
-                weights_combination=coeff.detach()
-            )
+            agg_fwd(agg_output, feat.detach(), coeff.detach(), mfg)
 
         ctx.save_for_backward(feat, coeff)
         return agg_output
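The `leading_dimension` logic above fixes the width of the aggregation output: one `in_feat`-wide block per relation without basis coefficients, one block per basis with them. A small worked example of the shapes (pure PyTorch; no pylibcugraphops required):

import torch as th

num_dst_nodes, in_feat, num_rels, num_bases = 4, 10, 3, 2

# coeff is None: output is (num_dst_nodes, num_rels * in_feat).
agg_no_basis = th.empty(num_dst_nodes, num_rels * in_feat)
assert agg_no_basis.shape == (4, 30)

# coeff given, shaped (num_rels, num_bases): the basis decomposition shrinks
# the output to (num_dst_nodes, num_bases * in_feat).
coeff = th.randn(num_rels, num_bases)
agg_basis = th.empty(num_dst_nodes, coeff.shape[-1] * in_feat)
assert agg_basis.shape == (4, 20)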
@@ -130,26 +117,19 @@ class RelGraphConvAgg(th.autograd.Function):
         """
         feat, coeff = ctx.saved_tensors
-        if ctx.idtype == th.int32:
-            agg_bwd_func = agg_hg_basis_post_bwd_int32
-        else:
-            agg_bwd_func = agg_hg_basis_post_bwd_int64
-        grad_feat = th.empty_like(feat, dtype=th.float32, device=feat.device)
+        grad_feat = th.empty_like(feat)
+        grad_coeff = None if coeff is None else th.empty_like(coeff)
         if coeff is None:
-            grad_coeff = None
-            agg_bwd_func(grad_feat, grad_output, feat.detach(), ctx.mfg)
+            agg_bwd(grad_feat, None, grad_output, feat.detach(), None, ctx.mfg)
         else:
-            grad_coeff = th.empty_like(
-                coeff, dtype=th.float32, device=coeff.device
-            )
-            agg_bwd_func(
+            agg_bwd(
                 grad_feat,
+                grad_coeff,
                 grad_output,
                 feat.detach(),
+                coeff.detach(),
                 ctx.mfg,
-                output_weight_gradient=grad_coeff,
-                weights_combination=coeff.detach()
             )
         return None, None, None, None, grad_feat, grad_coeff
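Both `agg_fwd` and `agg_bwd` write into preallocated tensors rather than returning them, and `backward` must return one gradient per `forward` input, hence the four leading `None`s for the non-differentiable arguments (`g`, `num_rels`, `edge_types`, `max_in_degree`). A minimal sketch of the same structure with a stand-in op (pure PyTorch, not the cuGraph kernel):

import torch as th

class ScaleAgg(th.autograd.Function):
    """Stand-in op mimicking the out-parameter style of agg_fwd/agg_bwd."""

    @staticmethod
    def forward(ctx, scale, feat):
        ctx.scale = scale
        out = th.empty_like(feat)     # preallocate, like agg_output above
        th.mul(feat, scale, out=out)  # kernel fills `out` in place
        return out

    @staticmethod
    def backward(ctx, grad_output):
        grad_feat = th.empty_like(grad_output)
        th.mul(grad_output, ctx.scale, out=grad_feat)
        # One return value per forward input; None for the non-tensor `scale`.
        return None, grad_feat

x = th.randn(3, requires_grad=True)
ScaleAgg.apply(2.0, x).sum().backward()
assert th.allclose(x.grad, th.full_like(x, 2.0))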
@@ -225,6 +205,7 @@ class CuGraphRelGraphConv(nn.Module):
             [-1.4335, -2.3758],
             [-1.4331, -2.3295]], device='cuda:0', grad_fn=<AddBackward0>)
     """
+
     def __init__(
         self,
         in_feat,
@@ -237,8 +218,13 @@
         self_loop=True,
         dropout=0.0,
         layer_norm=False,
-        max_in_degree=None
+        max_in_degree=None,
     ):
+        if has_pylibcugraphops is False:
+            raise ModuleNotFoundError(
+                "dgl.nn.CuGraphRelGraphConv requires pylibcugraphops "
+                "to be installed."
+            )
         super().__init__()
         self.in_feat = in_feat
         self.out_feat = out_feat
...
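For completeness, a hedged usage sketch of the layer after this change. The constructor's middle parameters are elided in this diff, so `num_rels`, `regularizer`, and `num_bases` below are inferred from the test file's kwargs, and the forward call mirrors DGL's RelGraphConv convention; running it requires a CUDA device and a pylibcugraphops build that provides the new bindings:

import torch
import dgl
from dgl.nn import CuGraphRelGraphConv

device = "cuda"
in_feat, out_feat, num_rels = 10, 2, 3

g = dgl.rand_graph(30, 100).to(device)  # illustrative random graph
g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),)).to(device)
feat = torch.ones(g.num_nodes(), in_feat).to(device)

conv = CuGraphRelGraphConv(
    in_feat, out_feat, num_rels,
    regularizer="basis", num_bases=2, self_loop=False, max_in_degree=8,
).to(device)
out = conv(g, feat, g.edata[dgl.ETYPE])  # expected shape: (num_nodes, out_feat)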
tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py

@@ -4,9 +4,10 @@ import dgl
 from dgl.nn import CuGraphRelGraphConv
 from dgl.nn import RelGraphConv
 
+# TODO(tingyu66): Re-enable the following tests after updating cuGraph CI image.
+use_longs = [False, True]
 max_in_degrees = [None, 8]
-# TODO(tingyu66): add back 'None' to regularizers when re-enabling CI
-regularizers = ["basis"]
+regularizers = [None, "basis"]
 device = "cuda"
@@ -18,9 +19,11 @@ def generate_graph():
     g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),))
     return g
 
+@pytest.mark.skip()
+@pytest.mark.parametrize('use_long', use_longs)
 @pytest.mark.parametrize('max_in_degree', max_in_degrees)
 @pytest.mark.parametrize("regularizer", regularizers)
-def test_full_graph(max_in_degree, regularizer):
+def test_full_graph(use_long, max_in_degree, regularizer):
     in_feat, out_feat, num_rels, num_bases = 10, 2, 3, 2
     kwargs = {
         "num_bases": num_bases,
@@ -29,6 +32,10 @@ def test_full_graph(max_in_degree, regularizer):
         "self_loop": False,
     }
     g = generate_graph().to(device)
+    if use_long:
+        g = g.long()
+    else:
+        g = g.int()
     feat = torch.ones(g.num_nodes(), in_feat).to(device)
     torch.manual_seed(0)
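The new `use_long` parameter exercises both graph index dtypes, which the single `make_mfg_csr_hg` binding now handles without the old int32/int64 dispatch; `g.long()` and `g.int()` are DGL's standard casts between int64 and int32 node/edge IDs. A quick CPU-only illustration:

import torch
import dgl

g = dgl.rand_graph(5, 10)  # DGL graphs default to int64 node/edge IDs
assert g.idtype == torch.int64

g32 = g.int()              # cast index arrays to int32
assert g32.idtype == torch.int32
assert g32.long().idtype == torch.int64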
@@ -53,6 +60,7 @@ def test_full_graph(max_in_degree, regularizer):
         conv1.linear_r.coeff.grad, conv2.coeff.grad, atol=1e-6
     )
 
+@pytest.mark.skip()
 @pytest.mark.parametrize('max_in_degree', max_in_degrees)
 @pytest.mark.parametrize("regularizer", regularizers)
 def test_mfg(max_in_degree, regularizer):
...