"...pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "571340dac63d3a09e5d66d45244f9f13bb175d00"
Unverified Commit 62c827c8 authored by Quan (Andy) Gan, committed by GitHub

[Bug] A bunch of fixes in edge_softmax_hetero (#4336)



* bunch of fixes

* Update test_edge_softmax_hetero.py

* Update test_edge_softmax_hetero.py
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 5ba5106a
@@ -135,6 +135,7 @@ def edge_softmax(graph, logits, eids=ALL, norm_by='dst'):
                              eids=eids, norm_by=norm_by)
     else:
         logits_list = [None] * graph._graph.number_of_etypes()
+        logits = {graph.to_canonical_etype(k): v for k, v in logits.items()}
         for rel in graph.canonical_etypes:
             etid = graph.get_etype_id(rel)
             logits_list[etid] = logits[rel]
...
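With this change, `edge_softmax` on a heterograph accepts a logits dict keyed by short edge-type names: each key is normalized through `to_canonical_etype` before the per-relation lookup, which previously raised a `KeyError` unless callers passed full `(src_type, etype, dst_type)` triples. A minimal sketch of the now-accepted call pattern (the graph, node/edge types, and scores below are illustrative, not from this patch):

    import torch
    import dgl
    from dgl.ops import edge_softmax

    g = dgl.heterograph({
        ('user', 'plays', 'game'): ([0, 1, 1], [0, 0, 1]),
        ('user', 'likes', 'game'): ([0, 2], [1, 1]),
    })
    scores = {
        # Short etype keys; the fix maps them to canonical triples internally.
        'plays': torch.randn(g.num_edges('plays'), 1),
        'likes': torch.randn(g.num_edges('likes'), 1),
    }
    out = edge_softmax(g, scores)          # normalized per destination node
    print(out[('user', 'plays', 'game')])  # results are keyed by canonical etype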
@@ -365,11 +365,11 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
     for l, arg_u_nd in enumerate(list_arg_u_nd):
         # TODO(Israt): l or src_id as index of lhs
         list_arg_u[l] = None if list_arg_u[l] is None else F.zerocopy_from_dgl_ndarray(arg_u_nd)
-        if expand_u and use_cmp:
+        if list_arg_u[l] is not None and expand_u and use_cmp:
             list_arg_u[l] = F.squeeze(list_arg_u[l], -1)
     for l, arg_e_nd in enumerate(list_arg_e_nd):
         list_arg_e[l] = None if list_arg_e[l] is None else F.zerocopy_from_dgl_ndarray(arg_e_nd)
-        if expand_e and use_cmp:
+        if list_arg_e[l] is not None and expand_e and use_cmp:
             list_arg_e[l] = F.squeeze(list_arg_e[l], -1)
     for l, arg_u_ntype_nd in enumerate(list_arg_u_ntype_nd):
         list_arg_u_ntype[l] = None if arg_u_ntype_nd is None \
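The `is not None` guards added above matter because `list_arg_u` and `list_arg_e` hold one slot per node/edge type, and slots for types not involved in the current relation stay `None`; calling `F.squeeze` on a `None` entry raised an error. A small standalone illustration of the guard pattern (plain PyTorch, illustrative names, not the DGL internals):

    import torch

    # One arg slot per type; types unused by this relation stay None.
    args = [torch.ones(4, 1), None, torch.ones(2, 1)]
    # Squeeze the trailing comparison dim only where a tensor is present,
    # mirroring the guards added in _gspmm_hetero.
    args = [a if a is None else a.squeeze(-1) for a in args]
    assert args[1] is None and args[0].shape == (4,)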
@@ -562,7 +562,7 @@ def _gsddmm_hetero(gidx, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rh
         e = out_list[l]
         e = F.tensor([]) if e is None else e
         if (expand_lhs or not use_lhs) and (expand_rhs or not use_rhs):
-            e = F.squeeze(v, -1)
+            e = F.squeeze(e, -1)
         out_list[l] = e
     out = tuple(out_list)
     return out
...
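The `_gsddmm_hetero` change is a one-character typo fix: the squeeze targeted an unrelated variable `v` instead of the per-etype output `e`, so the scalar-expansion path never actually removed the trailing dimension from the output it was building. A hedged sketch of the intended behavior (illustrative tensors, not the DGL internals):

    import torch

    out_list = [torch.ones(5, 1), None, torch.ones(3, 1)]
    for l, e in enumerate(out_list):
        e = torch.tensor([]) if e is None else e
        # Both operands were scalar-expanded, so drop *this* output's
        # trailing dim; squeezing some other tensor left `e` unchanged.
        if e.dim() > 1 and e.size(-1) == 1:
            e = e.squeeze(-1)
        out_list[l] = e
    assert out_list[0].shape == (5,) and out_list[1].shape == (0,)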
@@ -2,6 +2,7 @@ import dgl
 from dgl.ops import edge_softmax
 import dgl.function as fn
 from collections import Counter
+import math
 import numpy as np
 import scipy.sparse as ssp
 import itertools
@@ -17,8 +18,6 @@ rfuncs = {'sum': fn.sum, 'max': fn.max, 'min': fn.min, 'mean': fn.mean}
 fill_value = {'sum': 0, 'max': float("-inf")}
 feat_size = 2
 
-@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
 def create_test_heterograph(idtype):
     # test heterograph from the docstring, plus a user -- wishes -- game relation
     # 3 users, 2 games, 2 developers
@@ -37,8 +36,26 @@ def create_test_heterograph(idtype):
     assert g.idtype == idtype
     assert g.device == F.ctx()
     return g
 
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+def test_edge_softmax_unidirectional():
+    g = dgl.heterograph({
+        ('A', 'AB', 'B'): ([1,2,3,1,2,3,1,2,3], [0,0,0,1,1,1,2,2,2]),
+        ('B', 'BB', 'B'): ([0,1,2,0,1,2,0,1,2], [0,0,0,1,1,1,2,2,2])})
+    g = g.to(F.ctx())
+    g.edges['AB'].data['x'] = F.ones(9) * 2
+    g.edges['BB'].data['x'] = F.ones(9)
+    result = dgl.ops.edge_softmax(g, {'AB': g.edges['AB'].data['x'], 'BB': g.edges['BB'].data['x']})
+    ab = result['A', 'AB', 'B']
+    bb = result['B', 'BB', 'B']
+    e2 = F.zeros_like(ab) + math.exp(2) / ((math.exp(2) + math.exp(1)) * 3)
+    e1 = F.zeros_like(bb) + math.exp(1) / ((math.exp(2) + math.exp(1)) * 3)
+    assert F.allclose(ab, e2)
+    assert F.allclose(bb, e1)
+
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
 @pytest.mark.parametrize('g', get_cases(['clique']))
 @pytest.mark.parametrize('norm_by', ['src', 'dst'])
 # @pytest.mark.parametrize('shp', edge_softmax_shapes)
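The expected values in the new `test_edge_softmax_unidirectional` follow directly from destination-side normalization: every 'B' node has six in-edges, three 'AB' edges with logit 2 and three 'BB' edges with logit 1, so each 'AB' edge should receive exp(2) / (3 * (exp(2) + exp(1))) and each 'BB' edge exp(1) / (3 * (exp(2) + exp(1))). A quick standalone check of that arithmetic (plain Python, independent of DGL):

    import math

    # 6 in-edges per destination: 3 with logit 2 ('AB') and 3 with logit 1 ('BB').
    denom = 3 * (math.exp(2) + math.exp(1))
    p_ab, p_bb = math.exp(2) / denom, math.exp(1) / denom
    # The softmax over each node's in-edges must sum to 1.
    assert abs(3 * p_ab + 3 * p_bb - 1.0) < 1e-12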
@@ -109,5 +126,4 @@ def test_edge_softmax(g, norm_by, idtype):
     assert F.allclose(grad_edata_hm, grad_edata_ht)
 
 if __name__ == '__main__':
-    test_edge_softmax()
+    test_edge_softmax_unidirectional()