Unverified Commit 62c827c8 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] A bunch of fixes in edge_softmax_hetero (#4336)



* bunch of fixes

* Update test_edge_softmax_hetero.py

* Update test_edge_softmax_hetero.py
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 5ba5106a
......@@ -135,6 +135,7 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
eids=eids, norm_by=norm_by)
else:
logits_list = [None] * graph._graph.number_of_etypes()
logits = {graph.to_canonical_etype(k): v for k, v in logits.items()}
for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel)
logits_list[etid] = logits[rel]
......
......@@ -365,11 +365,11 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
for l, arg_u_nd in enumerate(list_arg_u_nd):
# TODO(Israt): l or src_id as index of lhs
list_arg_u[l] = None if list_arg_u[l] is None else F.zerocopy_from_dgl_ndarray(arg_u_nd)
if expand_u and use_cmp:
if list_arg_u[l] is not None and expand_u and use_cmp:
list_arg_u[l] = F.squeeze(list_arg_u[l], -1)
for l, arg_e_nd in enumerate(list_arg_e_nd):
list_arg_e[l] = None if list_arg_e[l] is None else F.zerocopy_from_dgl_ndarray(arg_e_nd)
if expand_e and use_cmp:
if list_arg_e[l] is not None and expand_e and use_cmp:
list_arg_e[l] = F.squeeze(list_arg_e[l], -1)
for l, arg_u_ntype_nd in enumerate(list_arg_u_ntype_nd):
list_arg_u_ntype[l] = None if arg_u_ntype_nd is None \
......@@ -562,7 +562,7 @@ def _gsddmm_hetero(gidx, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rh
e = out_list[l]
e = F.tensor([]) if e is None else e
if (expand_lhs or not use_lhs) and (expand_rhs or not use_rhs):
e = F.squeeze(v, -1)
e = F.squeeze(e, -1)
out_list[l] = e
out = tuple(out_list)
return out
......
......@@ -2,6 +2,7 @@ import dgl
from dgl.ops import edge_softmax
import dgl.function as fn
from collections import Counter
import math
import numpy as np
import scipy.sparse as ssp
import itertools
......@@ -17,8 +18,6 @@ rfuncs = {'sum': fn.sum, 'max': fn.max, 'min': fn.min, 'mean': fn.mean}
fill_value = {'sum': 0, 'max': float("-inf")}
feat_size = 2
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
def create_test_heterograph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
......@@ -38,7 +37,25 @@ def create_test_heterograph(idtype):
assert g.device == F.ctx()
return g
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
def test_edge_softmax_unidirectional():
g = dgl.heterograph({
('A', 'AB', 'B'): ([1,2,3,1,2,3,1,2,3],[0,0,0,1,1,1,2,2,2]),
('B', 'BB', 'B'): ([0,1,2,0,1,2,0,1,2], [0,0,0,1,1,1,2,2,2])})
g = g.to(F.ctx())
g.edges['AB'].data['x'] = F.ones(9) * 2
g.edges['BB'].data['x'] = F.ones(9)
result = dgl.ops.edge_softmax(g, {'AB': g.edges['AB'].data['x'], 'BB': g.edges['BB'].data['x']})
ab = result['A', 'AB', 'B']
bb = result['B', 'BB', 'B']
e2 = F.zeros_like(ab) + math.exp(2) / ((math.exp(2) + math.exp(1)) * 3)
e1 = F.zeros_like(bb) + math.exp(1) / ((math.exp(2) + math.exp(1)) * 3)
assert F.allclose(ab, e2)
assert F.allclose(bb, e1)
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize('g', get_cases(['clique']))
@pytest.mark.parametrize('norm_by', ['src', 'dst'])
# @pytest.mark.parametrize('shp', edge_softmax_shapes)
......@@ -109,5 +126,4 @@ def test_edge_softmax(g, norm_by, idtype):
assert F.allclose(grad_edata_hm, grad_edata_ht)
if __name__ == '__main__':
test_edge_softmax()
test_edge_softmax_unidirectional()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment