Unverified commit 49a4436a authored by Minjie Wang, committed by GitHub

[MPOps] Add edge-wise message passing operators u_op_v (#4801)

* edgewise mpops

* formatting

* black formatting

* isort

* fix lint

* remove typing module; address comments

* fix gpu ut; rename test file
parent dccf1f16
@@ -51,6 +51,7 @@ from . import optim
from .frame import LazyFeature
from .utils import apply_each
from .global_config import is_libxsmm_enabled, use_libxsmm
from .mpops import *
from ._deprecate.graph import DGLGraph as DGLGraphStale
from ._deprecate.nodeflow import *
"""Message passing operator sub-package"""
from .edgewise import *
from .nodewise import *
from .fused import *
"""Operators for computing edge data."""
import sys
from .. import ops
__all__ = ["copy_u", "copy_v"]
#######################################################
# Edge-wise operators that fetch node data to edges
#######################################################
def copy_u(g, x_node, etype=None):
"""Compute new edge data by fetching from source node data.
Given an input graph :math:`G(V, E)` (or a unidirectional bipartite graph
:math:`G(V_{src}, V_{dst}, E)`) and an input tensor :math:`X`,
the operator computes a tensor :math:`Y` storing the new edge data.
For each edge :math:`e=(u,v) \\in E`, it computes:
.. math::

    Y_e = X_u
Parameters
----------
g : DGLGraph
The input graph.
x_node : Tensor
The tensor storing the source node data. Shape :math:`(|V_{src}|, *)`.
etype : str or (str, str, str), optional
Edge type. If not specified, the input graph must have only one type of
edges.
Returns
-------
Tensor
The tensor storing the new edge data. Shape :math:`(|E|, *)`.
Examples
--------
**Homogeneous graph**
>>> import torch, dgl
>>> g = dgl.rand_graph(100, 500) # a random graph of 100 nodes, 500 edges
>>> x = torch.randn(g.num_nodes(), 5) # 5 features
>>> y = dgl.copy_u(g, x)
>>> print(y.shape)
(500, 5)
**Heterogeneous graph**
>>> hg = dgl.heterograph({
... ('user', 'follow', 'user'): ([0, 1, 2], [2, 3, 4]),
... ('user', 'like', 'movie'): ([3, 3, 1, 2], [0, 0, 1, 1])
... })
>>> x = torch.randn(hg.num_nodes('user'), 5)
>>> y = dgl.copy_u(hg, x, etype='like')
>>> print(y.shape)
(4, 5)
"""
etype_subg = g if etype is None else g[etype]
return ops.gsddmm(etype_subg, "copy_lhs", x_node, None)
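# Equivalence sketch (not part of the library): "copy_lhs" selects the
# source-node operand of the generalized SDDMM kernel, so the result
# matches a plain gather over the edges' source IDs, as the unit tests
# below verify:
#
#   src, _ = g.edges()
#   y = x_node[src.long()]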
def copy_v(g, x_node, etype=None):
"""Compute new edge data by fetching from destination node data.
Given an input graph :math:`G(V, E)` (or a unidirectional bipartite graph
:math:`G(V_{src}, V_{dst}, E)`) and an input tensor :math:`X`,
the operator computes a tensor :math:`Y` storing the new edge data.
For each edge :math:`e=(u,v) \\in E`, it computes:
.. math::

    Y_e = X_v
Parameters
----------
g : DGLGraph
The input graph.
x_node : Tensor
The tensor storing the destination node data. Shape :math:`(|V_{dst}|, *)`.
etype : str or (str, str, str), optional
Edge type. If not specified, the input graph must have
only one type of edges.
Returns
-------
Tensor
The tensor storing the new edge data. Shape :math:`(|E|, *)`.
Examples
--------
**Homogeneous graph**
>>> import torch, dgl
>>> g = dgl.rand_graph(100, 500) # a random graph of 100 nodes, 500 edges
>>> x = torch.randn(g.num_nodes(), 5) # 5 features
>>> y = dgl.copy_v(g, x)
>>> print(y.shape)
(500, 5)
**Heterogeneous graph**
>>> hg = dgl.heterograph({
... ('user', 'follow', 'user'): ([0, 1, 2], [2, 3, 4]),
... ('user', 'like', 'movie'): ([3, 3, 1, 2], [0, 0, 1, 1])
... })
>>> x = torch.randn(hg.num_nodes('movie'), 5)
>>> y = dgl.copy_v(hg, x, etype='like')
>>> print(y.shape)
(4, 5)
"""
etype_subg = g if etype is None else g[etype]
return ops.gsddmm(etype_subg, "copy_rhs", None, x_node)
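# Symmetric sketch for copy_v: "copy_rhs" selects the destination-node
# operand, i.e. a gather over the edges' destination IDs:
#
#   _, dst = g.edges()
#   y = x_node[dst.long()]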
#######################################################
# Binary edge-wise operators
#######################################################
def _gen_u_op_v(op):
"""Internal helper function to create binary edge-wise operators.
The function will return a Python function with:
- Name: u_{op}_v
- Docstring template
Parameters
----------
op : str
Binary operator name. Must be 'add', 'sub', 'mul', 'div' or 'dot'.
"""
name = f"u_{op}_v"
op_verb = {
"add": "adding",
"sub": "subtracting",
"mul": "multiplying",
"div": "dividing",
"dot": "dot-product",
}
docstring = f"""Compute new edge data by {op_verb[op]} the source node data
and destination node data.
Given an input graph :math:`G(V, E)` (or a unidirectional bipartite graph
:math:`G(V_{{src}}, V_{{dst}}, E)`) and two input tensors :math:`X` and
:math:`Y`, the operator computes a tensor :math:`Z` storing the new edge data.
For each edge :math:`e=(u,v) \\in E`, it computes:
.. math::

    Z_e = {op}(X_u, Y_v)
If :math:`X_u` and :math:`Y_v` are vectors or high-dimensional tensors, the
operation is element-wise and supports shape broadcasting. Read more about
`NumPy's broadcasting semantics
<https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_.
Parameters
----------
g : DGLGraph
The input graph.
x_node : Tensor
The tensor storing the source node data. Shape :math:`(|V_{{src}}|, *)`.
y_node : Tensor
The tensor storing the destination node data. Shape :math:`(|V_{{dst}}|, *)`.
etype : str or (str, str, str), optional
Edge type. If not specified, the input graph must have
only one type of edges.
Returns
-------
Tensor
The tensor storing the new edge data. Shape :math:`(|E|, *)`.
Examples
--------
**Homogeneous graph**
>>> import torch, dgl
>>> g = dgl.rand_graph(100, 500) # a random graph of 100 nodes, 500 edges
>>> x = torch.randn(g.num_nodes(), 5) # 5 features
>>> y = torch.randn(g.num_nodes(), 5) # 5 features
>>> z = dgl.{name}(g, x, y)
>>> print(z.shape)
(500, 5)
**Heterogeneous graph**
>>> hg = dgl.heterograph({{
... ('user', 'follow', 'user'): ([0, 1, 2], [2, 3, 4]),
... ('user', 'like', 'movie'): ([3, 3, 1, 2], [0, 0, 1, 1])
... }})
>>> x = torch.randn(hg.num_nodes('user'), 5)
>>> y = torch.randn(hg.num_nodes('user'), 5)
>>> z = dgl.{name}(hg, x, y, etype='follow')
>>> print(z.shape)
(3, 5)
**Shape broadcasting**
>>> x = torch.randn(g.num_nodes(), 5) # 5 features
>>> y = torch.randn(g.num_nodes(), 1) # one feature
>>> z = dgl.{name}(g, x, y)
>>> print(z.shape)
(500, 5)
"""
def func(g, x_node, y_node, etype=None):
etype_subg = g if etype is None else g[etype]
return ops.gsddmm(etype_subg, op, x_node, y_node, lhs_target="u", rhs_target="v")
func.__name__ = name
func.__doc__ = docstring
return func
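# Sketch of what a generated operator computes, using u_add_v as the
# example (memory here scales with the number of edges; the gsddmm kernel
# avoids materializing the gathered per-edge operands):
#
#   src, dst = g.edges()
#   z = x_node[src.long()] + y_node[dst.long()]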
def _register_func(func):
setattr(sys.modules[__name__], func.__name__, func)
__all__.append(func.__name__)
_register_func(_gen_u_op_v("add"))
_register_func(_gen_u_op_v("sub"))
_register_func(_gen_u_op_v("mul"))
_register_func(_gen_u_op_v("div"))
_register_func(_gen_u_op_v("dot"))
"""Operators that fuse the computation and aggregation of edge data."""
"""Operators for aggregating/reducing edge data to node data."""
"""dgl spmm operator module."""
"""Internal module for general spmm operators."""
import sys
from .. import backend as F
......
import random
import backend as F
import numpy as np
import pytest
import torch
from test_utils import parametrize_idtype
import dgl
random.seed(42)
np.random.seed(42)
dgl.seed(42)
torch.random.manual_seed(42)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [(5,), ()])
def test_copy_u(idtype, feat_size):
ctx = F.ctx()
g = dgl.rand_graph(30, 100)
g = g.astype(idtype).to(ctx)
x = torch.randn((g.num_nodes(),) + feat_size, requires_grad=True, device=ctx)
y = dgl.copy_u(g, x)
y.sum().backward()
x_grad = x.grad.clone()  # clone: zero_() below would otherwise alias it
x.grad.zero_()
u, v = g.edges()
y_true = x[u.long()]
y_true.sum().backward()
x_grad_true = x.grad
assert torch.allclose(y, y_true)
assert torch.allclose(x_grad, x_grad_true)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [(5,), ()])
def test_copy_u_hetero(idtype, feat_size):
ctx = F.ctx()
hg = dgl.heterograph(
{
("user", "follow", "user"): ([0, 1, 2], [2, 3, 4]),
("user", "like", "movie"): ([3, 3, 1, 2], [0, 0, 1, 1]),
}
)
hg = hg.astype(idtype).to(ctx)
x = torch.randn((hg.num_nodes("user"),) + feat_size, requires_grad=True, device=ctx)
y = dgl.copy_u(hg, x, etype="like")
y.sum().backward()
x_grad = x.grad.clone()  # clone: zero_() below would otherwise alias it
x.grad.zero_()
u, v = hg.edges(etype="like")
y_true = x[u.long()]
y_true.sum().backward()
x_grad_true = x.grad
assert torch.allclose(y, y_true)
assert torch.allclose(x_grad, x_grad_true)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [(5,), ()])
def test_copy_v(idtype, feat_size):
ctx = F.ctx()
g = dgl.rand_graph(30, 100)
g = g.astype(idtype).to(ctx)
x = torch.randn((g.num_nodes(),) + feat_size, requires_grad=True, device=ctx)
y = dgl.copy_v(g, x)
y.sum().backward()
x_grad = x.grad.clone()  # clone: zero_() below would otherwise alias it
x.grad.zero_()
u, v = g.edges()
y_true = x[v.long()]
y_true.sum().backward()
x_grad_true = x.grad
assert torch.allclose(y, y_true)
assert torch.allclose(x_grad, x_grad_true)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [(5,), ()])
def test_copy_v_hetero(idtype, feat_size):
ctx = F.ctx()
hg = dgl.heterograph(
{
("user", "follow", "user"): ([0, 1, 2], [2, 3, 4]),
("user", "like", "movie"): ([3, 3, 1, 2], [0, 0, 1, 1]),
}
)
hg = hg.astype(idtype).to(ctx)
x = torch.randn((hg.num_nodes("movie"),) + feat_size, requires_grad=True, device=ctx)
y = dgl.copy_v(hg, x, etype="like")
y.sum().backward()
x_grad = x.grad.clone()  # clone: zero_() below would otherwise alias it
x.grad.zero_()
u, v = hg.edges(etype="like")
y_true = x[v.long()]
y_true.sum().backward()
x_grad_true = x.grad
assert torch.allclose(y, y_true)
assert torch.allclose(x_grad, x_grad_true)
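# Feature-shape pairs exercising element-wise broadcasting between source
# and destination features; () is a per-node scalar feature. The dot cases
# keep matching trailing dimensions so the reduction is well-defined.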
binary_arg_sizes = [
((5,), (5,)),
((5,), ()),
((), (5,)),
((1, 3, 3), (4, 1, 3)),
((3, 3), (4, 1, 3)),
((4, 1, 3), (3, 3)),
]
dot_arg_sizes = [
((5,), (5,)),
((1, 3, 3), (4, 1, 3)),
((3, 3), (4, 1, 3)),
((4, 1, 3), (3, 3)),
]
ops = ["add", "sub", "mul", "div"]
def pad_shape(x, y, x_size, y_size):
xy_size = torch.broadcast_shapes(x_size, y_size)
new_x_size = (1,) * (len(xy_size) - len(x_size)) + x_size
new_y_size = (1,) * (len(xy_size) - len(y_size)) + y_size
new_x = x.view(-1, *new_x_size)
new_y = y.view(-1, *new_y_size)
return new_x, new_y
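# Worked example (a sketch): with x_size=(3, 3) and y_size=(4, 1, 3),
# broadcast_shapes gives (4, 3, 3); the per-edge tensors are reshaped to
# (E, 1, 3, 3) and (E, 4, 1, 3) so that plain torch ops broadcast the same
# way the DGL kernel does:
#
#   x = torch.randn(100, 3, 3)  # 100 edges after gathering
#   y = torch.randn(100, 4, 1, 3)
#   xb, yb = pad_shape(x, y, (3, 3), (4, 1, 3))
#   assert (xb + yb).shape == (100, 4, 3, 3)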
@parametrize_idtype
@pytest.mark.parametrize("op", ops)
@pytest.mark.parametrize("x_size,y_size", binary_arg_sizes)
def test_u_op_v(idtype, op, x_size, y_size):
ctx = F.ctx()
g = dgl.rand_graph(30, 100)
g = g.astype(idtype).to(ctx)
x = torch.randn((g.num_nodes(),) + x_size, requires_grad=True, device=ctx)
y = torch.randn((g.num_nodes(),) + y_size, requires_grad=True, device=ctx)
f_dgl = getattr(dgl, f"u_{op}_v")
z = f_dgl(g, x, y)
z.sum().backward()
x_grad = x.grad.clone()  # clone before zeroing; the same buffers are reused
y_grad = y.grad.clone()
x.grad.zero_()
y.grad.zero_()
u, v = g.edges()
f_torch = getattr(torch, op)
x_u, y_v = pad_shape(x[u.long()], y[v.long()], x_size, y_size)
z_true = f_torch(x_u, y_v)
z_true.sum().backward()
x_grad_true = x.grad
y_grad_true = y.grad
assert torch.allclose(z, z_true)
assert torch.allclose(x_grad, x_grad_true)
assert torch.allclose(y_grad, y_grad_true)
@parametrize_idtype
@pytest.mark.parametrize("x_size,y_size", dot_arg_sizes)
def test_u_dot_v(idtype, x_size, y_size):
ctx = F.ctx()
g = dgl.rand_graph(30, 100)
g = g.astype(idtype).to(ctx)
x = torch.randn((g.num_nodes(),) + x_size, requires_grad=True, device=ctx)
y = torch.randn((g.num_nodes(),) + y_size, requires_grad=True, device=ctx)
z = dgl.u_dot_v(g, x, y)
z.sum().backward()
x_grad = x.grad.clone()  # clone before zeroing; the same buffers are reused
y_grad = y.grad.clone()
x.grad.zero_()
y.grad.zero_()
u, v = g.edges()
x_u, y_v = pad_shape(x[u.long()], y[v.long()], x_size, y_size)
z_true = (x_u * y_v).sum(-1).unsqueeze(-1)
z_true.sum().backward()
x_grad_true = x.grad
y_grad_true = y.grad
assert torch.allclose(z, z_true, atol=1e-4, rtol=1e-4)
assert torch.allclose(x_grad, x_grad_true)
assert torch.allclose(y_grad, y_grad_true)