Unverified Commit 5eca59d8 authored by Israt Nisa's avatar Israt Nisa Committed by GitHub
Browse files

[Feature] Add builtin binary op support in `apply_edges()` for heterogeneous graph (#3598)



* add unittest for binary ops

* Changed loss func
Co-authored-by: default avatarIsrat Nisa <nisisrat@amazon.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 88f5a8be
...@@ -4,6 +4,7 @@ from collections import Counter ...@@ -4,6 +4,7 @@ from collections import Counter
import numpy as np import numpy as np
import scipy.sparse as ssp import scipy.sparse as ssp
import itertools import itertools
from itertools import product
import backend as F import backend as F
import networkx as nx import networkx as nx
import unittest, pytest import unittest, pytest
...@@ -40,7 +41,7 @@ def create_test_heterograph(idtype): ...@@ -40,7 +41,7 @@ def create_test_heterograph(idtype):
@parametrize_dtype @parametrize_dtype
def test_unary_copy_u(idtype): def test_unary_copy_u(idtype):
def _test(mfunc, rfunc): def _test(mfunc):
g = create_test_heterograph(idtype) g = create_test_heterograph(idtype)
...@@ -53,7 +54,7 @@ def test_unary_copy_u(idtype): ...@@ -53,7 +54,7 @@ def test_unary_copy_u(idtype):
g.nodes['developer'].data['h'] = x2 g.nodes['developer'].data['h'] = x2
################################################################# #################################################################
# apply_edges() is called for each etype in a loop # apply_edges() is called on each relation type separately
################################################################# #################################################################
with F.record_grad(): with F.record_grad():
...@@ -66,7 +67,7 @@ def test_unary_copy_u(idtype): ...@@ -66,7 +67,7 @@ def test_unary_copy_u(idtype):
g.edata['m'].clear() g.edata['m'].clear()
################################################################# #################################################################
# apply_edges() is called for all etypes at once # apply_edges() is called on all relation types
################################################################# #################################################################
g.apply_edges(fn.copy_u('h', 'm')) g.apply_edges(fn.copy_u('h', 'm'))
...@@ -88,16 +89,12 @@ def test_unary_copy_u(idtype): ...@@ -88,16 +89,12 @@ def test_unary_copy_u(idtype):
_print_error(n_grad1, n_grad2) _print_error(n_grad1, n_grad2)
assert(F.allclose(n_grad1, n_grad2)) assert(F.allclose(n_grad1, n_grad2))
_test(fn.copy_u, fn.sum) _test(fn.copy_u)
# TODO(Israt) :Add reduce func to suport the following reduce op
# _test('copy_u', 'max')
# _test('copy_u', 'min')
# _test('copy_u', 'mean')
@parametrize_dtype @parametrize_dtype
def test_unary_copy_e(idtype): def test_unary_copy_e(idtype):
def _test(mfunc, rfunc): def _test(mfunc):
g = create_test_heterograph(idtype) g = create_test_heterograph(idtype)
feat_size = 2 feat_size = 2
...@@ -116,7 +113,7 @@ def test_unary_copy_e(idtype): ...@@ -116,7 +113,7 @@ def test_unary_copy_e(idtype):
g['wishes'].edata['eid'] = x4 g['wishes'].edata['eid'] = x4
################################################################# #################################################################
# apply_edges() is called for each etype in a loop # apply_edges() is called on each relation type separately
################################################################# #################################################################
with F.record_grad(): with F.record_grad():
[g.apply_edges(fn.copy_e('eid', 'm'), etype = rel) [g.apply_edges(fn.copy_e('eid', 'm'), etype = rel)
...@@ -126,7 +123,7 @@ def test_unary_copy_e(idtype): ...@@ -126,7 +123,7 @@ def test_unary_copy_e(idtype):
e_grad1 = F.grad(g['develops'].edata['eid']) e_grad1 = F.grad(g['develops'].edata['eid'])
################################################################# #################################################################
# apply_edges() is called for all etypes at the same time # apply_edges() is called on all relation types
################################################################# #################################################################
g.apply_edges(fn.copy_e('eid', 'm')) g.apply_edges(fn.copy_e('eid', 'm'))
...@@ -148,14 +145,103 @@ def test_unary_copy_e(idtype): ...@@ -148,14 +145,103 @@ def test_unary_copy_e(idtype):
_print_error(e_grad1, e_grad2) _print_error(e_grad1, e_grad2)
assert(F.allclose(e_grad1, e_grad2)) assert(F.allclose(e_grad1, e_grad2))
_test(fn.copy_e, fn.sum) _test(fn.copy_e)
# TODO(Israt) :Add reduce func to suport the following reduce op
# _test('copy_e', 'max')
# _test('copy_e', 'min') @parametrize_dtype
# _test('copy_e', 'mean') def test_binary_op(idtype):
def _test(lhs, rhs, binary_op):
g = create_test_heterograph(idtype)
n1 = F.randn((g.num_nodes('user'), feat_size))
n2 = F.randn((g.num_nodes('developer'), feat_size))
n3 = F.randn((g.num_nodes('game'), feat_size))
x1 = F.randn((g.num_edges('plays'),feat_size))
x2 = F.randn((g.num_edges('follows'),feat_size))
x3 = F.randn((g.num_edges('develops'),feat_size))
x4 = F.randn((g.num_edges('wishes'),feat_size))
builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
builtin_msg = getattr(fn, builtin_msg_name)
#################################################################
# apply_edges() is called on each relation type separately
#################################################################
F.attach_grad(n1)
F.attach_grad(n2)
F.attach_grad(n3)
g.nodes['user'].data['h'] = n1
g.nodes['developer'].data['h'] = n2
g.nodes['game'].data['h'] = n3
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g['plays'].edata['h'] = x1
g['follows'].edata['h'] = x2
g['develops'].edata['h'] = x3
g['wishes'].edata['h'] = x4
with F.record_grad():
[g.apply_edges(builtin_msg('h', 'h', 'm'), etype = rel)
for rel in g.canonical_etypes]
r1 = g['plays'].edata['m']
loss = F.sum(r1.view(-1), 0)
F.backward(loss)
n_grad1 = F.grad(g.nodes['game'].data['h'])
#################################################################
# apply_edges() is called on all relation types
#################################################################
F.attach_grad(n1)
F.attach_grad(n2)
F.attach_grad(n3)
g.nodes['user'].data['h'] = n1
g.nodes['developer'].data['h'] = n2
g.nodes['game'].data['h'] = n3
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g['plays'].edata['h'] = x1
g['follows'].edata['h'] = x2
g['develops'].edata['h'] = x3
g['wishes'].edata['h'] = x4
with F.record_grad():
g.apply_edges(builtin_msg('h', 'h', 'm'))
r2 = g['plays'].edata['m']
loss = F.sum(r2.view(-1), 0)
F.backward(loss)
n_grad2 = F.grad(g.nodes['game'].data['h'])
# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y):
print('@{} {} v.s. {}'.format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if n_grad1 is not None or n_grad2 is not None:
if not F.allclose(n_grad1, n_grad2):
print('node grad')
_print_error(n_grad1, n_grad2)
assert(F.allclose(n_grad1, n_grad2))
target = ["u", "v", "e"]
for lhs, rhs in product(target, target):
if lhs == rhs:
continue
for binary_op in ["add", "sub", "mul", "div", "dot"]:
print(lhs, rhs, binary_op)
_test(lhs, rhs, binary_op)
if __name__ == '__main__': if __name__ == '__main__':
test_unary_copy_u() test_unary_copy_u()
test_unary_copy_e() test_unary_copy_e()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment