"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "8b85ca5d326465a4344decc63653cb539a1b2f66"
Unverified Commit 5eca59d8 authored by Israt Nisa's avatar Israt Nisa Committed by GitHub
Browse files

[Feature] Add builtin binary op support in `apply_edges()` for heterogeneous graph (#3598)



* add unittest for binary ops

* Changed loss func
Co-authored-by: default avatarIsrat Nisa <nisisrat@amazon.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 88f5a8be
......@@ -4,6 +4,7 @@ from collections import Counter
import numpy as np
import scipy.sparse as ssp
import itertools
from itertools import product
import backend as F
import networkx as nx
import unittest, pytest
......@@ -40,7 +41,7 @@ def create_test_heterograph(idtype):
@parametrize_dtype
def test_unary_copy_u(idtype):
def _test(mfunc, rfunc):
def _test(mfunc):
g = create_test_heterograph(idtype)
......@@ -53,7 +54,7 @@ def test_unary_copy_u(idtype):
g.nodes['developer'].data['h'] = x2
#################################################################
# apply_edges() is called for each etype in a loop
# apply_edges() is called on each relation type separately
#################################################################
with F.record_grad():
......@@ -66,7 +67,7 @@ def test_unary_copy_u(idtype):
g.edata['m'].clear()
#################################################################
# apply_edges() is called for all etypes at once
# apply_edges() is called on all relation types
#################################################################
g.apply_edges(fn.copy_u('h', 'm'))
......@@ -88,16 +89,12 @@ def test_unary_copy_u(idtype):
_print_error(n_grad1, n_grad2)
assert(F.allclose(n_grad1, n_grad2))
_test(fn.copy_u, fn.sum)
# TODO(Israt) :Add reduce func to suport the following reduce op
# _test('copy_u', 'max')
# _test('copy_u', 'min')
# _test('copy_u', 'mean')
_test(fn.copy_u)
@parametrize_dtype
def test_unary_copy_e(idtype):
def _test(mfunc, rfunc):
def _test(mfunc):
g = create_test_heterograph(idtype)
feat_size = 2
......@@ -116,7 +113,7 @@ def test_unary_copy_e(idtype):
g['wishes'].edata['eid'] = x4
#################################################################
# apply_edges() is called for each etype in a loop
# apply_edges() is called on each relation type separately
#################################################################
with F.record_grad():
[g.apply_edges(fn.copy_e('eid', 'm'), etype = rel)
......@@ -126,7 +123,7 @@ def test_unary_copy_e(idtype):
e_grad1 = F.grad(g['develops'].edata['eid'])
#################################################################
# apply_edges() is called for all etypes at the same time
# apply_edges() is called on all relation types
#################################################################
g.apply_edges(fn.copy_e('eid', 'm'))
......@@ -148,14 +145,103 @@ def test_unary_copy_e(idtype):
_print_error(e_grad1, e_grad2)
assert(F.allclose(e_grad1, e_grad2))
_test(fn.copy_e, fn.sum)
# TODO(Israt) :Add reduce func to suport the following reduce op
# _test('copy_e', 'max')
# _test('copy_e', 'min')
# _test('copy_e', 'mean')
_test(fn.copy_e)
@parametrize_dtype
def test_binary_op(idtype):
def _test(lhs, rhs, binary_op):
g = create_test_heterograph(idtype)
n1 = F.randn((g.num_nodes('user'), feat_size))
n2 = F.randn((g.num_nodes('developer'), feat_size))
n3 = F.randn((g.num_nodes('game'), feat_size))
x1 = F.randn((g.num_edges('plays'),feat_size))
x2 = F.randn((g.num_edges('follows'),feat_size))
x3 = F.randn((g.num_edges('develops'),feat_size))
x4 = F.randn((g.num_edges('wishes'),feat_size))
builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
builtin_msg = getattr(fn, builtin_msg_name)
#################################################################
# apply_edges() is called on each relation type separately
#################################################################
F.attach_grad(n1)
F.attach_grad(n2)
F.attach_grad(n3)
g.nodes['user'].data['h'] = n1
g.nodes['developer'].data['h'] = n2
g.nodes['game'].data['h'] = n3
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g['plays'].edata['h'] = x1
g['follows'].edata['h'] = x2
g['develops'].edata['h'] = x3
g['wishes'].edata['h'] = x4
with F.record_grad():
[g.apply_edges(builtin_msg('h', 'h', 'm'), etype = rel)
for rel in g.canonical_etypes]
r1 = g['plays'].edata['m']
loss = F.sum(r1.view(-1), 0)
F.backward(loss)
n_grad1 = F.grad(g.nodes['game'].data['h'])
#################################################################
# apply_edges() is called on all relation types
#################################################################
F.attach_grad(n1)
F.attach_grad(n2)
F.attach_grad(n3)
g.nodes['user'].data['h'] = n1
g.nodes['developer'].data['h'] = n2
g.nodes['game'].data['h'] = n3
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g['plays'].edata['h'] = x1
g['follows'].edata['h'] = x2
g['develops'].edata['h'] = x3
g['wishes'].edata['h'] = x4
with F.record_grad():
g.apply_edges(builtin_msg('h', 'h', 'm'))
r2 = g['plays'].edata['m']
loss = F.sum(r2.view(-1), 0)
F.backward(loss)
n_grad2 = F.grad(g.nodes['game'].data['h'])
# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y):
print('@{} {} v.s. {}'.format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if n_grad1 is not None or n_grad2 is not None:
if not F.allclose(n_grad1, n_grad2):
print('node grad')
_print_error(n_grad1, n_grad2)
assert(F.allclose(n_grad1, n_grad2))
target = ["u", "v", "e"]
for lhs, rhs in product(target, target):
if lhs == rhs:
continue
for binary_op in ["add", "sub", "mul", "div", "dot"]:
print(lhs, rhs, binary_op)
_test(lhs, rhs, binary_op)
if __name__ == '__main__':
test_unary_copy_u()
test_unary_copy_e()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment