import itertools
import unittest
from collections import Counter
from itertools import product

import backend as F

import dgl
import dgl.function as fn
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
import test_utils
from dgl import DGLError
from scipy.sparse import rand
from test_utils import get_cases, parametrize_idtype

rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
fill_value = {"sum": 0, "max": float("-inf")}
feat_size = 2


def _print_error(a, b):
    """Print every flattened position at which tensors ``a`` and ``b`` differ."""
    pairs = zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
    for i, (x, y) in enumerate(pairs):
        if not np.allclose(x, y):
            print("@{} {} v.s. {}".format(i, x, y))


def _assert_allclose(a, b, label=None):
    """Assert ``F.allclose(a, b)``, dumping the mismatching entries first.

    ``label`` (if given) is printed ahead of the mismatch dump so the failing
    quantity can be identified in the captured output.
    """
    close = F.allclose(a, b)
    if not close:
        if label is not None:
            print(label)
        _print_error(a, b)
    assert close


def _snapshot_grad(x):
    """Return an independent copy of the gradient of ``x`` (``None`` if unset).

    NOTE(review): this assumes ``F.grad`` may hand back the live gradient
    buffer (as ``torch.Tensor.grad`` does). Without the copy, a later
    ``F.attach_grad`` / ``F.backward`` on the same tensor would mutate the
    value captured for the first pass, making the comparison between the two
    passes trivially true. Cloning is harmless if the backend already copies.
    """
    grad = F.grad(x)
    return None if grad is None else F.clone(grad)


# The decorator wraps the graph factory itself, so every test that builds its
# graph through it is skipped on backends other than PyTorch.
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
    """Build the small heterograph shared by all tests in this module.

    Parameters
    ----------
    idtype
        Integer id type the graph is created with.

    Returns
    -------
    The heterograph, placed on ``F.ctx()``.
    """
    # test heterograph from the docstring, plus a user -- wishes -- game relation
    # 3 users, 2 games, 2 developers
    # metagraph:
    #    ('user', 'follows', 'user'),
    #    ('user', 'plays', 'game'),
    #    ('user', 'wishes', 'game'),
    #    ('developer', 'develops', 'game')
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
            ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


@parametrize_idtype
def test_unary_copy_u(idtype):
    """Per-relation and all-relation ``apply_edges`` must agree for copy_u."""

    def _test(mfunc):
        g = create_test_heterograph(idtype)

        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))

        def _bind_features():
            # (Re-)attach gradients and store the features on the graph, so
            # each pass starts from a fresh gradient state.
            F.attach_grad(x1)
            F.attach_grad(x2)
            g.nodes["user"].data["h"] = x1
            g.nodes["developer"].data["h"] = x2

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################

        _bind_features()
        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            # Copy the gradient: the second pass reuses the same leaf tensor.
            n_grad1 = _snapshot_grad(g.ndata["h"]["user"])

        # TODO (Israt): clear not working
        # NOTE(review): this call is kept from the original; per the TODO it
        # does not actually remove "m" from the graph. The second pass below
        # writes "m" again for every relation type.
        g.edata["m"].clear()

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################

        _bind_features()
        with F.record_grad():
            g.apply_edges(mfunc("h", "m"))
            r2 = g["plays"].edata["m"]
            F.backward(r2, F.ones(r2.shape))
            n_grad2 = F.grad(g.nodes["user"].data["h"])

        # correctness check
        _assert_allclose(r1, r2)
        _assert_allclose(n_grad1, n_grad2, "node grad")

    _test(fn.copy_u)


@parametrize_idtype
def test_unary_copy_e(idtype):
    """Per-relation and all-relation ``apply_edges`` must agree for copy_e."""

    def _test(mfunc):
        g = create_test_heterograph(idtype)

        # One feature tensor per relation type, sized from the graph itself.
        edge_feats = {
            etype: F.randn((g.num_edges(etype), feat_size))
            for etype in ("plays", "follows", "develops", "wishes")
        }

        def _bind_features():
            # (Re-)attach gradients and store the features on the graph, so
            # each pass starts from a fresh gradient state.
            for etype, feat in edge_feats.items():
                F.attach_grad(feat)
                g[etype].edata["eid"] = feat

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################

        _bind_features()
        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(mfunc("eid", "m"), etype=rel)
            r1 = g["develops"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            # Copy the gradient: the second pass reuses the same leaf tensor.
            e_grad1 = _snapshot_grad(g["develops"].edata["eid"])

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################

        _bind_features()
        with F.record_grad():
            g.apply_edges(mfunc("eid", "m"))
            r2 = g["develops"].edata["m"]
            F.backward(r2, F.ones(r2.shape))
            e_grad2 = F.grad(g["develops"].edata["eid"])

        # correctness check
        _assert_allclose(r1, r2)
        _assert_allclose(e_grad1, e_grad2, "edge grad")

    _test(fn.copy_e)


@parametrize_idtype
def test_binary_op(idtype):
    """Per-relation and all-relation ``apply_edges`` must agree for the
    built-in binary message functions (``u_add_v``, ``e_mul_u``, ...)."""

    def _test(lhs, rhs, binary_op):
        g = create_test_heterograph(idtype)

        node_feats = {
            ntype: F.randn((g.num_nodes(ntype), feat_size))
            for ntype in ("user", "developer", "game")
        }
        edge_feats = {
            etype: F.randn((g.num_edges(etype), feat_size))
            for etype in ("plays", "follows", "develops", "wishes")
        }

        # Resolve e.g. ("u", "v", "add") to the built-in ``fn.u_add_v``.
        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
        builtin_msg = getattr(fn, builtin_msg_name)

        def _bind_features():
            # (Re-)attach gradients and store the features on the graph, so
            # each pass starts from a fresh gradient state.
            for ntype, feat in node_feats.items():
                F.attach_grad(feat)
                g.nodes[ntype].data["h"] = feat
            for etype, feat in edge_feats.items():
                F.attach_grad(feat)
                g[etype].edata["h"] = feat

        #################################################################
        #  apply_edges() is called on each relation type separately
        #################################################################

        _bind_features()
        with F.record_grad():
            for rel in g.canonical_etypes:
                g.apply_edges(builtin_msg("h", "h", "m"), etype=rel)
            r1 = g["plays"].edata["m"]
            loss = F.sum(r1.view(-1), 0)
            F.backward(loss)
            # Copy the gradient: the second pass reuses the same leaf tensor.
            n_grad1 = _snapshot_grad(g.nodes["game"].data["h"])

        #################################################################
        #  apply_edges() is called on all relation types
        #################################################################

        _bind_features()
        with F.record_grad():
            g.apply_edges(builtin_msg("h", "h", "m"))
            r2 = g["plays"].edata["m"]
            loss = F.sum(r2.view(-1), 0)
            F.backward(loss)
            n_grad2 = F.grad(g.nodes["game"].data["h"])

        # correctness check
        _assert_allclose(r1, r2)

        # The "game" features get no gradient when the message does not read
        # the destination node; in that case both passes must report None.
        if n_grad1 is None and n_grad2 is None:
            return
        assert (
            n_grad1 is not None and n_grad2 is not None
        ), "only one of the two passes produced a node gradient"
        _assert_allclose(n_grad1, n_grad2, "node grad")

    target = ["u", "v", "e"]
    for lhs, rhs in product(target, target):
        if lhs == rhs:
            continue
        for binary_op in ["add", "sub", "mul", "div", "dot"]:
            print(lhs, rhs, binary_op)
            _test(lhs, rhs, binary_op)


if __name__ == "__main__":
    # The tests take ``idtype`` as a required argument (normally supplied by
    # ``parametrize_idtype``), so pass it explicitly when run as a script.
    for _idtype in (F.int32, F.int64):
        test_unary_copy_u(_idtype)
        test_unary_copy_e(_idtype)
        test_binary_op(_idtype)