Unverified Commit 88f5a8be authored by Israt Nisa, committed by GitHub

[Feature] Add heterogeneous graph API for `edge_softmax` (#3571)



* edge_softmax_hetero forward+cpu+norm=dst

* convert eids to list

* added unittest

* added unittest

* added backward; correctness not yet tested

* minor

* changed reducer from sum to max

* bugfix

* docstring

* add GPU unittest

* output converted to dict from tuple

* lint check
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 25538ba4
...@@ -1625,6 +1625,43 @@ def edge_softmax(gidx, logits, eids, norm_by):
"""
pass
def edge_softmax_hetero(gidx, eids, norm_by, *logits):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
By default edge softmax is normalized by destination nodes (i.e. :math:`ij`
are incoming edges of `i` in the formula above). We also support edge
softmax normalized by source nodes (i.e. :math:`ij` are outgoing edges of
`i` in the formula). The former case corresponds to softmax in GAT and
Transformer, and the latter case corresponds to softmax in the Capsule network.
Parameters
----------
gidx : HeteroGraphIndex
The graph to perform edge softmax on.
eids : dict of tensors
Each tensor holds the edges on which to apply edge softmax for the
corresponding relation type.
logits : tuple of tensors
The input edge features of different relation types.
norm_by : str, could be `src` or `dst`
Normalized by source nodes or destination nodes. Default: `dst`.
Returns
-------
tuple of tensors
Softmax values, one tensor per relation type.
"""
pass
def segment_reduce(op, x, offsets):
"""Segment reduction operator.
...
...@@ -26,8 +26,8 @@ else:
return bwd(*args, **kwargs)
return decorate_bwd
__all__ = ['gspmm', 'gsddmm', 'gspmm_hetero', 'gsddmm_hetero', 'edge_softmax', 'edge_softmax_hetero',
'segment_reduce', 'scatter_add', 'csrmm', 'csrsum', 'csrmask']
def _reduce_grad(grad, shape):
...@@ -501,10 +501,81 @@ class EdgeSoftmax(th.autograd.Function):
out, = ctx.saved_tensors
sds = out * grad_out
accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
return None, grad_score, None, None
class EdgeSoftmax_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, eids, norm_by, *score):
"""Forward function.
Pseudo-code:
.. code:: python
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
"""
# remember to save the graph to backward cache before making it
# a local variable
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src':
gidx = gidx.reverse()
u_len = gidx.number_of_ntypes()
e_len = gidx.number_of_etypes()
lhs = [None] * u_len
feats = tuple(lhs + list(score))
score_max = _gspmm_hetero(gidx, 'copy_rhs', 'max', u_len, feats)[0]
out_tmp = _gsddmm_hetero(gidx, 'sub', e_len, 'e', 'v', tuple(list(score) + list(score_max)))
score = tuple([th.exp(out_tmp[i]) if out_tmp[i] is not None else None
for i in range(len(out_tmp))])
score_sum = _gspmm_hetero(gidx, 'copy_rhs', 'sum', u_len, tuple(lhs + list(score)))[0]
out = _gsddmm_hetero(gidx, 'div', e_len, 'e', 'v', tuple(list(score) + list(score_sum)))
ctx.backward_cache = gidx
ctx.save_for_backward(*out)
return out
@staticmethod
@custom_bwd
def backward(ctx, *grad_out):
"""Backward function.
Pseudo-code:
.. code:: python
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - out * sds_sum # multiple expressions
return grad_score.data
"""
gidx = ctx.backward_cache
# See https://github.com/dmlc/dgl/pull/3386
ctx.backward_cache = None
u_len = gidx.number_of_ntypes()
e_len = gidx.number_of_etypes()
lhs = [None] * u_len
out = ctx.saved_tensors
sds = tuple([out[i] * grad_out[i]
for i in range(len(out))])
accum = _gspmm_hetero(gidx, 'copy_rhs', 'sum', u_len, tuple(lhs + list(sds)))[0]
out_sddmm = _gsddmm_hetero(gidx, 'mul', e_len, 'e', 'v', tuple(list(out) + list(accum)))
grad_score = tuple([sds[i] - out_sddmm[i]
for i in range(len(sds))])
return (None, None, None) + grad_score
class SegmentReduce(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
...@@ -659,6 +730,9 @@ def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_t
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
def edge_softmax_hetero(gidx, eids=ALL, norm_by='dst', *logits):
return EdgeSoftmax_hetero.apply(gidx, eids, norm_by, *logits)
def segment_reduce(op, x, offsets): def segment_reduce(op, x, offsets):
return SegmentReduce.apply(op, x, offsets) return SegmentReduce.apply(op, x, offsets)
......
"""dgl edge_softmax operator module.""" """dgl edge_softmax operator module."""
from ..backend import edge_softmax as edge_softmax_internal from ..backend import edge_softmax as edge_softmax_internal
from ..backend import edge_softmax_hetero as edge_softmax_hetero_internal
from ..backend import astype
from ..base import ALL, is_all
...@@ -34,8 +35,9 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
----------
graph : DGLGraph
The graph over which edge softmax will be performed.
logits : torch.Tensor or dict of torch.Tensor
The input edge feature. For a heterogeneous graph, this can be a dict of
tensors, where each tensor stores the edge features of the corresponding
relation type.
eids : torch.Tensor or ALL, optional
The IDs of the edges to apply edge softmax. If ALL, it will apply edge
softmax to all edges in the graph. Default: ALL.
...@@ -44,7 +46,7 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
Returns
-------
Tensor or dict of tensors
Softmax value.
Notes
...@@ -55,8 +57,8 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
the graph.
* Return shape: :math:`(E, *, 1)`
Examples on a homogeneous graph
-------------------------------
The following example uses PyTorch backend.
>>> from dgl.nn.functional import edge_softmax
...@@ -102,8 +104,45 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
[0.5000],
[1.0000],
[0.5000]])
Examples on a heterogeneous graph
---------------------------------
Create a heterogeneous graph and initialize its edge features.
>>> hg = dgl.heterograph({
... ('user', 'follows', 'user'): ([0, 0, 1], [0, 1, 2]),
... ('developer', 'develops', 'game'): ([0, 1], [0, 1])
... })
>>> edata_follows = th.ones(3, 1).float()
>>> edata_develops = th.ones(2, 1).float()
>>> edata_dict = {('user', 'follows', 'user'): edata_follows,
... ('developer', 'develops', 'game'): edata_develops}
Apply edge softmax over hg normalized by source nodes:
>>> edge_softmax(hg, edata_dict, norm_by='src')
{('developer', 'develops', 'game'): tensor([[1.],
[1.]]), ('user', 'follows', 'user'): tensor([[0.5000],
[0.5000],
[1.0000]])}
""" """
if not is_all(eids): if not is_all(eids):
eids = astype(eids, graph.idtype) eids = astype(eids, graph.idtype)
return edge_softmax_internal(graph._graph, logits, if graph._graph.number_of_etypes() == 1:
eids=eids, norm_by=norm_by) return edge_softmax_internal(graph._graph, logits,
eids=eids, norm_by=norm_by)
else:
logits_list = [None] * graph._graph.number_of_etypes()
for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel)
logits_list[etid] = logits[rel]
logits_tuple = tuple(logits_list)
score_tuple = edge_softmax_hetero_internal(graph._graph,
eids, norm_by, *logits_tuple)
score = {}
for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel)
score[rel] = score_tuple[etid]
return score
...@@ -456,7 +456,7 @@ def _gsddmm_hetero(gidx, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rh
[to_dgl_nd_for_write(out) for out in out_list],
lhs_target, rhs_target)
for l in range(gidx.number_of_etypes()):
# Replace None by empty tensor. Forward func doesn't accept None in tuple.
e = out_list[l]
e = F.tensor([]) if e is None else e
...
...@@ -159,4 +159,3 @@ if __name__ == '__main__':
test_unary_copy_u()
test_unary_copy_e()
import dgl
from dgl.ops import edge_softmax
import dgl.function as fn
from collections import Counter
import numpy as np
import scipy.sparse as ssp
import itertools
import backend as F
import networkx as nx
import unittest, pytest
from dgl import DGLError
import test_utils
from test_utils import parametrize_dtype, get_cases
from scipy.sparse import rand
rfuncs = {'sum': fn.sum, 'max': fn.max, 'min': fn.min, 'mean': fn.mean}
fill_value = {'sum': 0, 'max': float("-inf")}
feat_size = 2
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
def create_test_heterograph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
# metagraph:
# ('user', 'follows', 'user'),
# ('user', 'plays', 'game'),
# ('user', 'wishes', 'game'),
# ('developer', 'develops', 'game')])
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1, 2, 1, 1], [0, 0, 1, 1, 2]),
('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
('user', 'wishes', 'game'): ([0, 1, 1], [0, 0, 1]),
('developer', 'develops', 'game'): ([0, 1, 0], [0, 1, 1]),
}, idtype=idtype, device=F.ctx())
assert g.idtype == idtype
assert g.device == F.ctx()
return g
@pytest.mark.parametrize('g', get_cases(['clique']))
@pytest.mark.parametrize('norm_by', ['src', 'dst'])
# @pytest.mark.parametrize('shp', edge_softmax_shapes)
@parametrize_dtype
def test_edge_softmax(g, norm_by, idtype):
print("params", norm_by, idtype)
g = create_test_heterograph(idtype)
x1 = F.randn((g.num_edges('plays'), feat_size))
x2 = F.randn((g.num_edges('follows'), feat_size))
x3 = F.randn((g.num_edges('develops'), feat_size))
x4 = F.randn((g.num_edges('wishes'), feat_size))
F.attach_grad(F.clone(x1))
F.attach_grad(F.clone(x2))
F.attach_grad(F.clone(x3))
F.attach_grad(F.clone(x4))
g['plays'].edata['eid'] = x1
g['follows'].edata['eid'] = x2
g['develops'].edata['eid'] = x3
g['wishes'].edata['eid'] = x4
#################################################################
# edge_softmax() on homogeneous graph
#################################################################
with F.record_grad():
hm_g = dgl.to_homogeneous(g)
hm_x = F.cat((x3, x2, x1, x4), 0)
hm_e = F.attach_grad(F.clone(hm_x))
score_hm = edge_softmax(hm_g, hm_e, norm_by=norm_by)
hm_g.edata['score'] = score_hm
ht_g = dgl.to_heterogeneous(hm_g, g.ntypes, g.etypes)
r1 = ht_g.edata['score'][('user', 'plays', 'game')]
r2 = ht_g.edata['score'][('user', 'follows', 'user')]
r3 = ht_g.edata['score'][('developer', 'develops', 'game')]
r4 = ht_g.edata['score'][('user', 'wishes', 'game')]
F.backward(F.reduce_sum(r1) + F.reduce_sum(r2))
grad_edata_hm = F.grad(hm_e)
#################################################################
# edge_softmax() on heterogeneous graph
#################################################################
e1 = F.attach_grad(F.clone(x1))
e2 = F.attach_grad(F.clone(x2))
e3 = F.attach_grad(F.clone(x3))
e4 = F.attach_grad(F.clone(x4))
e = {('user', 'follows', 'user'): e2,
('user', 'plays', 'game'): e1,
('user', 'wishes', 'game'): e4,
('developer', 'develops', 'game'): e3}
with F.record_grad():
score = edge_softmax(g, e, norm_by=norm_by)
r5 = score[('user', 'plays', 'game')]
r6 = score[('user', 'follows', 'user')]
r7 = score[('developer', 'develops', 'game')]
r8 = score[('user', 'wishes', 'game')]
F.backward(F.reduce_sum(r5) + F.reduce_sum(r6))
grad_edata_ht = F.cat((F.grad(e3), F.grad(e2), F.grad(e1), F.grad(e4)), 0)
# correctness check
assert F.allclose(r1, r5)
assert F.allclose(r2, r6)
assert F.allclose(r3, r7)
assert F.allclose(r4, r8)
assert F.allclose(grad_edata_hm, grad_edata_ht)
if __name__ == '__main__':
test_edge_softmax()