Unverified Commit 88f5a8be authored by Israt Nisa, committed by GitHub

[Feature] Add heterogeneous graph API for `edge_softmax` (#3571)



* edge_softmax_hetero forward + CPU + norm=dst

* convert eids to list

* added unittest

* added unittest

* added backward. Not tested correctness

* minor

* changed reducer to max from sum

* bugfix

* docstring

* add GPU unittest

* output converted to dict from tuple

* lint check
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 25538ba4
......@@ -1625,6 +1625,43 @@ def edge_softmax(gidx, logits, eids, norm_by):
"""
pass
def edge_softmax_hetero(gidx, eids, norm_by, *logits):
r"""Compute edge softmax.
For a node :math:`i`, edge softmax is an operation of computing
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
By default edge softmax is normalized by destination nodes(i.e. :math:`ij`
are incoming edges of `i` in the formula above). We also support edge
softmax normalized by source nodes(i.e. :math:`ij` are outgoing edges of
`i` in the formula). The previous case correspond to softmax in GAT and
Transformer, and the later case correspond to softmax in Capsule network.
Parameters
----------
gidx : HeteroGraphIndex
The graph to perfor edge softmax on.
eids : dict of tensors
Each tensor has the edges on which to apply edge softmax for a
corresponsing relation type.
logits : tuple of tensors
The input edge features of different relation types.
norm_by : str, could be `src` or `dst`
Normalized by source nodes or destination nodes. Default: `dst`.
Returns
-------
Tensor
Softmax value
"""
pass
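For intuition, edge softmax with norm_by='dst' is an ordinary softmax applied independently to each group of edges that share a destination node. A minimal plain-PyTorch sketch of the same computation (the toy edge list and logits below are illustrative, not part of this change):

import torch

# toy graph: four edges with destination nodes [0, 1, 2, 2]
dst = torch.tensor([0, 1, 2, 2])
logits = torch.tensor([1.0, 2.0, 0.5, 1.5])

# softmax within each group of edges sharing a destination
out = torch.empty_like(logits)
for node in dst.unique():
    mask = dst == node
    out[mask] = torch.softmax(logits[mask], dim=0)
# single-edge groups get 1.0; the two edges into node 2 get [0.2689, 0.7311]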
def segment_reduce(op, x, offsets):
"""Segment reduction operator.
......
......@@ -26,8 +26,8 @@ else:
return bwd(*args, **kwargs)
return decorate_bwd
__all__ = ['gspmm', 'gsddmm', 'gspmm_hetero', 'gsddmm_hetero', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']
__all__ = ['gspmm', 'gsddmm', 'gspmm_hetero', 'gsddmm_hetero', 'edge_softmax', 'edge_softmax_hetero',
'segment_reduce', 'scatter_add', 'csrmm', 'csrsum', 'csrmask']
def _reduce_grad(grad, shape):
......@@ -501,10 +501,81 @@ class EdgeSoftmax(th.autograd.Function):
out, = ctx.saved_tensors
sds = out * grad_out
accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
return None, grad_score, None, None
class EdgeSoftmax_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, gidx, eids, norm_by, *score):
"""Forward function.
Pseudo-code:
.. code:: python
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
"""
# remember to save the graph to backward cache before making it
# a local variable
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == 'src':
gidx = gidx.reverse()
u_len = gidx.number_of_ntypes()
e_len = gidx.number_of_etypes()
lhs = [None] * u_len
feats = tuple(lhs + list(score))
score_max = _gspmm_hetero(gidx, 'copy_rhs', 'max', u_len, feats)[0]
out_tmp = _gsddmm_hetero(gidx, 'sub', e_len, 'e', 'v', tuple(list(score) + list(score_max)))
score = tuple([th.exp(out_tmp[i]) if out_tmp[i] is not None else None
for i in range(len(out_tmp))])
score_sum = _gspmm_hetero(gidx, 'copy_rhs', 'sum', u_len, tuple(lhs + list(score)))[0]
out = _gsddmm_hetero(gidx, 'div', e_len, 'e', 'v', tuple(list(score) + list(score_sum)))
ctx.backward_cache = gidx
ctx.save_for_backward(*out)
return out
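The forward pass reduces with 'max' rather than summing raw exponentials (see the "changed reducer to max from sum" commit above): subtracting the per-destination max before exp() leaves the softmax unchanged but keeps the exponentials from overflowing. A standalone illustration of the identity (not part of the kernel code):

import torch

z = torch.tensor([1000.0, 1001.0, 1002.0])

naive = torch.exp(z) / torch.exp(z).sum()  # exp overflows: inf / inf -> nan
stable = torch.exp(z - z.max())
stable = stable / stable.sum()             # tensor([0.0900, 0.2447, 0.6652])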
@staticmethod
@custom_bwd
def backward(ctx, *grad_out):
"""Backward function.
Pseudo-code:
.. code:: python
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - out * sds_sum # multiple expressions
return grad_score.data
"""
gidx = ctx.backward_cache
# See https://github.com/dmlc/dgl/pull/3386
ctx.backward_cache = None
u_len = gidx.number_of_ntypes()
e_len = gidx.number_of_etypes()
lhs = [None] * u_len
out = ctx.saved_tensors
sds = tuple([out[i] * grad_out[i]
for i in range(len(out))])
accum = _gspmm_hetero(gidx, 'copy_rhs', 'sum', u_len, tuple(lhs + list(sds)))[0]
out_sddmm = _gsddmm_hetero(gidx, 'mul', e_len, 'e', 'v', tuple(list(out) + list(accum)))
grad_score = tuple([sds[i] - out_sddmm[i]
for i in range(len(sds))])
return (None, None, None) + grad_score
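The backward pass is the standard softmax vector-Jacobian product: with sds = out * grad_out, the gradient of the logits is sds - out * sum_dst(sds). A quick single-destination check against autograd (illustrative sketch, not part of this change):

import torch

# one destination node with four incoming edges
z = torch.randn(4, requires_grad=True)
out = torch.softmax(z, dim=0)
grad_out = torch.randn(4)
out.backward(grad_out)

sds = out.detach() * grad_out            # 'sds' in the code above
manual = sds - out.detach() * sds.sum()  # sds - out * accum, per destination
assert torch.allclose(z.grad, manual)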
class SegmentReduce(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
......@@ -659,6 +730,9 @@ def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_t
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
def edge_softmax_hetero(gidx, eids=ALL, norm_by='dst', *logits):
return EdgeSoftmax_hetero.apply(gidx, eids, norm_by, *logits)
def segment_reduce(op, x, offsets):
return SegmentReduce.apply(op, x, offsets)
......
"""dgl edge_softmax operator module."""
from ..backend import edge_softmax as edge_softmax_internal
from ..backend import edge_softmax_hetero as edge_softmax_hetero_internal
from ..backend import astype
from ..base import ALL, is_all
......@@ -34,8 +35,9 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
----------
graph : DGLGraph
The graph over which edge softmax will be performed.
logits : torch.Tensor
The input edge feature.
logits : torch.Tensor or dict of torch.Tensor
    The input edge feature. For a heterogeneous graph, this can be a dict
    of tensors, where each tensor stores the edge features of the
    corresponding relation type.
eids : torch.Tensor or ALL, optional
The IDs of the edges to apply edge softmax. If ALL, it will apply edge
softmax to all edges in the graph. Default: ALL.
......@@ -44,7 +46,7 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
Returns
-------
Tensor
Tensor or dict of tensors
Softmax value.
Notes
......@@ -55,8 +57,8 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
the graph.
* Return shape: :math:`(E, *, 1)`
Examples
--------
Examples on a homogeneous graph
-------------------------------
The following example uses PyTorch backend.
>>> from dgl.nn.functional import edge_softmax
......@@ -102,8 +104,45 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
[0.5000],
[1.0000],
[0.5000]])
Examples on a heterogeneous graph
---------------------------------
Create a heterogeneous graph and initialize its edge features.
>>> hg = dgl.heterograph({
... ('user', 'follows', 'user'): ([0, 0, 1], [0, 1, 2]),
... ('developer', 'develops', 'game'): ([0, 1], [0, 1])
... })
>>> edata_follows = th.ones(3, 1).float()
>>> edata_develops = th.ones(2, 1).float()
>>> edata_dict = {('user', 'follows', 'user'): edata_follows,
... ('developer', 'develops', 'game'): edata_develops}
Apply edge softmax over hg normalized by source nodes:
>>> edge_softmax(hg, edata_dict, norm_by='src')
{('developer', 'develops', 'game'): tensor([[1.],
[1.]]), ('user', 'follows', 'user'): tensor([[0.5000],
[0.5000],
[1.0000]])}
"""
if not is_all(eids):
eids = astype(eids, graph.idtype)
if graph._graph.number_of_etypes() == 1:
return edge_softmax_internal(graph._graph, logits,
eids=eids, norm_by=norm_by)
else:
logits_list = [None] * graph._graph.number_of_etypes()
for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel)
logits_list[etid] = logits[rel]
logits_tuple = tuple(logits_list)
score_tuple = edge_softmax_hetero_internal(graph._graph,
eids, norm_by, *logits_tuple)
score = {}
for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel)
score[rel] = score_tuple[etid]
return score
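A quick end-to-end use of the dispatch above, reusing the heterograph from the docstring (dict of logits in, dict of scores out; with the default norm_by='dst', the incoming scores of each destination node sum to one):

import dgl
import torch as th
from dgl.ops import edge_softmax

hg = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 0, 1], [0, 1, 2]),
    ('developer', 'develops', 'game'): ([0, 1], [0, 1])})
edata_dict = {('user', 'follows', 'user'): th.ones(3, 1),
              ('developer', 'develops', 'game'): th.ones(2, 1)}

scores = edge_softmax(hg, edata_dict)  # norm_by='dst' by default
# each relation gets back a tensor shaped like its input logits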
......@@ -456,7 +456,7 @@ def _gsddmm_hetero(gidx, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rh
[to_dgl_nd_for_write(out) for out in out_list],
lhs_target, rhs_target)
for l in range(gidx.number_of_ntypes()):
for l in range(gidx.number_of_etypes()):
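# out_list holds one output per relation (edge type), so the loop bound
# must be number_of_etypes(); the old number_of_ntypes() bound was only
# correct when the two counts happened to coincide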
# Replace None by empty tensor. Forward func doesn't accept None in tuple.
e = out_list[l]
e = F.tensor([]) if e is None else e
......
......@@ -159,4 +159,3 @@ if __name__ == '__main__':
test_unary_copy_u()
test_unary_copy_e()
......
import dgl
from dgl.ops import edge_softmax
import dgl.function as fn
from collections import Counter
import numpy as np
import scipy.sparse as ssp
import itertools
import backend as F
import networkx as nx
import unittest, pytest
from dgl import DGLError
import test_utils
from test_utils import parametrize_dtype, get_cases
from scipy.sparse import rand
rfuncs = {'sum': fn.sum, 'max': fn.max, 'min': fn.min, 'mean': fn.mean}
fill_value = {'sum': 0, 'max': float("-inf")}
feat_size = 2
def create_test_heterograph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
# metagraph:
# ('user', 'follows', 'user'),
# ('user', 'plays', 'game'),
# ('user', 'wishes', 'game'),
# ('developer', 'develops', 'game')
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1, 2, 1, 1], [0, 0, 1, 1, 2]),
('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
('user', 'wishes', 'game'): ([0, 1, 1], [0, 0, 1]),
('developer', 'develops', 'game'): ([0, 1, 0], [0, 1, 1]),
}, idtype=idtype, device=F.ctx())
assert g.idtype == idtype
assert g.device == F.ctx()
return g
@pytest.mark.parametrize('g', get_cases(['clique']))
@pytest.mark.parametrize('norm_by', ['src', 'dst'])
# @pytest.mark.parametrize('shp', edge_softmax_shapes)
@parametrize_dtype
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
def test_edge_softmax(g, norm_by, idtype):
print("params", norm_by, idtype)
g = create_test_heterograph(idtype)
x1 = F.randn((g.num_edges('plays'), feat_size))
x2 = F.randn((g.num_edges('follows'), feat_size))
x3 = F.randn((g.num_edges('develops'), feat_size))
x4 = F.randn((g.num_edges('wishes'), feat_size))
g['plays'].edata['eid'] = x1
g['follows'].edata['eid'] = x2
g['develops'].edata['eid'] = x3
g['wishes'].edata['eid'] = x4
#################################################################
# edge_softmax() on homogeneous graph
#################################################################
with F.record_grad():
hm_g = dgl.to_homogeneous(g)
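# concatenate in etype-ID order (develops, follows, plays, wishes) so the
# homogeneous edge features line up with to_homogeneous's edge ordering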
hm_x = F.cat((x3, x2, x1, x4), 0)
hm_e = F.attach_grad(F.clone(hm_x))
score_hm = edge_softmax(hm_g, hm_e, norm_by=norm_by)
hm_g.edata['score'] = score_hm
ht_g = dgl.to_heterogeneous(hm_g, g.ntypes, g.etypes)
r1 = ht_g.edata['score'][('user', 'plays', 'game')]
r2 = ht_g.edata['score'][('user', 'follows', 'user')]
r3 = ht_g.edata['score'][('developer', 'develops', 'game')]
r4 = ht_g.edata['score'][('user', 'wishes', 'game')]
F.backward(F.reduce_sum(r1) + F.reduce_sum(r2))
grad_edata_hm = F.grad(hm_e)
#################################################################
# edge_softmax() on heterogeneous graph
#################################################################
e1 = F.attach_grad(F.clone(x1))
e2 = F.attach_grad(F.clone(x2))
e3 = F.attach_grad(F.clone(x3))
e4 = F.attach_grad(F.clone(x4))
e = {('user', 'follows', 'user'): e2,
('user', 'plays', 'game'): e1,
('user', 'wishes', 'game'): e4,
('developer', 'develops', 'game'): e3}
with F.record_grad():
score = edge_softmax(g, e, norm_by=norm_by)
r5 = score[('user', 'plays', 'game')]
r6 = score[('user', 'follows', 'user')]
r7 = score[('developer', 'develops', 'game')]
r8 = score[('user', 'wishes', 'game')]
F.backward(F.reduce_sum(r5) + F.reduce_sum(r6))
grad_edata_ht = F.cat((F.grad(e3), F.grad(e2), F.grad(e1), F.grad(e4)), 0)
# correctness check
assert F.allclose(r1, r5)
assert F.allclose(r2, r6)
assert F.allclose(r3, r7)
assert F.allclose(r4, r8)
assert F.allclose(grad_edata_hm, grad_edata_ht)
if __name__ == '__main__':
    # run one concrete parameter combination when invoked directly
    test_edge_softmax(None, 'dst', F.int32)