Unverified Commit 2b98e764 authored by Mufei Li, committed by GitHub

[Transform] Modules for Augmentation (#3668)



* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Fix lint

* lint

* Update

* Update

* lint fix

* Fix CI

* Fix

* Fix CI

* Update

* Fix

* Update

* Update

* Augmentation (#10)

* Update

* PPR

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* CI

* lint

* lint

* Update

* Update

* Fix AddEdge

* try import

* Update

* Fix

* CI
Co-authored-by: Ubuntu <ubuntu@ip-172-31-31-136.us-west-2.compute.internal>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent ba62b730
@@ -13,6 +13,12 @@ BaseTransform
     :members: __call__, __repr__
     :show-inheritance:
 
+Compose
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: Compose
+    :show-inheritance:
+
 AddSelfLoop
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -55,8 +61,50 @@ AddMetaPaths
 .. autoclass:: AddMetaPaths
     :show-inheritance:
 
-KNNGraph
+GCNNorm
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: GCNNorm
+    :show-inheritance:
+
+PPR
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: PPR
+    :show-inheritance:
+
+HeatKernel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: HeatKernel
+    :show-inheritance:
+
+GDC
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: GDC
+    :show-inheritance:
+
+NodeShuffle
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: NodeShuffle
+    :show-inheritance:
+
+DropNode
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: DropNode
+    :show-inheritance:
+
+DropEdge
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: DropEdge
+    :show-inheritance:
+
+AddEdge
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-.. autoclass:: KNNGraph
+.. autoclass:: AddEdge
     :show-inheritance:
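The new documentation entries cover the augmentation modules this PR adds. A minimal usage sketch, relying only on the constructor defaults exercised by the tests further down (the toy graph itself is hypothetical):

import dgl

# A toy graph to augment: 3 nodes in a directed cycle.
g = dgl.graph(([0, 1, 2], [1, 2, 0]))

# Compose chains transforms; DropEdge and AddEdge randomly remove and
# add edges each time the composed transform is applied.
transform = dgl.Compose([dgl.DropEdge(), dgl.AddEdge()])
new_g = transform(g)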
@@ -570,6 +570,21 @@ def exp(input):
     """
     pass
 
+def inverse(input):
+    """Returns the inverse matrix of a square matrix if it exists.
+
+    Parameters
+    ----------
+    input : Tensor
+        The input square matrix.
+
+    Returns
+    -------
+    Tensor
+        The output tensor.
+    """
+    pass
+
 def sqrt(input):
     """Returns a new tensor with the square root of the elements of the input tensor `input`.
@@ -1057,6 +1072,21 @@ def equal(x, y):
     """
     pass
 
+def allclose(x, y, rtol=1e-4, atol=1e-4):
+    """Compares whether all elements are close.
+
+    Parameters
+    ----------
+    x : Tensor
+        First tensor
+    y : Tensor
+        Second tensor
+    rtol : float, optional
+        Relative tolerance
+    atol : float, optional
+        Absolute tolerance
+    """
+    pass
+
 def logical_not(input):
     """Perform a logical not operation. Equivalent to np.logical_not
......
@@ -191,6 +191,9 @@ def argsort(input, dim, descending):
 def exp(input):
     return nd.exp(input)
 
+def inverse(input):
+    return nd.linalg_inverse(input)
+
 def sqrt(input):
     return nd.sqrt(input)
@@ -327,6 +330,9 @@ def boolean_mask(input, mask):
 def equal(x, y):
     return x == y
 
+def allclose(x, y, rtol=1e-4, atol=1e-4):
+    return np.allclose(x.asnumpy(), y.asnumpy(), rtol=rtol, atol=atol)
+
 def logical_not(input):
     return nd.logical_not(input)
......
@@ -14,8 +14,8 @@ from ..._deprecate import kernel as K
 from ...function.base import TargetCode
 from ...base import dgl_warning
 
-if LooseVersion(th.__version__) < LooseVersion("1.5.0"):
-    raise Exception("Detected an old version of PyTorch. Please update torch>=1.5.0 "
+if LooseVersion(th.__version__) < LooseVersion("1.8.0"):
+    raise Exception("Detected an old version of PyTorch. Please update torch>=1.8.0 "
                     "for the best experience.")
 
 def data_type_dict():
@@ -164,6 +164,9 @@ def argtopk(input, k, dim, descending=True):
 def exp(input):
     return th.exp(input)
 
+def inverse(input):
+    return th.inverse(input)
+
 def sqrt(input):
     return th.sqrt(input)
@@ -276,6 +279,9 @@ def boolean_mask(input, mask):
 def equal(x, y):
     return x == y
 
+def allclose(x, y, rtol=1e-4, atol=1e-4):
+    return th.allclose(x, y, rtol=rtol, atol=atol)
+
 def logical_not(input):
     return ~input
......
@@ -244,6 +244,10 @@ def exp(input):
     return tf.exp(input)
 
+def inverse(input):
+    return tf.linalg.inv(input)
+
 def sqrt(input):
     return tf.sqrt(input)
@@ -396,6 +400,11 @@ def equal(x, y):
     return x == y
 
+def allclose(x, y, rtol=1e-4, atol=1e-4):
+    return np.allclose(tf.convert_to_tensor(x).numpy(),
+                       tf.convert_to_tensor(y).numpy(), rtol=rtol, atol=atol)
+
 def logical_not(input):
     return ~input
......
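Each backend supplies its own implementation of the two new interface ops, and DGL binds whichever backend is selected at import time. A rough sketch of that dispatch pattern, as a hypothetical illustration (the real loader in dgl.backend differs in detail):

import importlib
import os

# Hypothetical illustration: choose the backend named by the DGLBACKEND
# environment variable and bind the ops declared in the interface above.
backend_name = os.environ.get('DGLBACKEND', 'pytorch')
mod = importlib.import_module('dgl.backend.%s.tensor' % backend_name)

inverse = mod.inverse    # th.inverse / nd.linalg_inverse / tf.linalg.inv
allclose = mod.allclose  # tolerance-based elementwise comparison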
This diff is collapsed.
@@ -345,8 +345,8 @@ def test_empty_data_initialized():
     assert len(g.ndata["ha"]) == 1
 
 def test_is_sorted():
     u_src, u_dst = edge_pair_input(False)
     s_src, s_dst = edge_pair_input(True)
     u_src = F.tensor(u_src, dtype=F.int32)
     u_dst = F.tensor(u_dst, dtype=F.int32)
@@ -409,7 +409,7 @@ def test_formats():
         fail = False
     finally:
         assert not fail
 
 if __name__ == '__main__':
     test_query()
     test_mutation()
......
@@ -23,6 +23,7 @@ import dgl.function as fn
 import dgl.partition
 import backend as F
 import unittest
+import math
 from utils import parametrize_dtype
 from test_heterograph import create_test_heterograph3, create_test_heterograph4, create_test_heterograph5
@@ -2156,5 +2157,144 @@ def test_module_compose(idtype):
     eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
     assert eset == {(0, 1), (1, 2), (1, 0), (2, 1), (0, 0), (1, 1), (2, 2)}
 
+@parametrize_dtype
+def test_module_gcnnorm(idtype):
+    g = dgl.heterograph({
+        ('A', 'r1', 'A'): ([0, 1, 2], [0, 0, 1]),
+        ('A', 'r2', 'B'): ([0, 0], [1, 1]),
+        ('B', 'r3', 'B'): ([0, 1, 2], [0, 0, 1])
+    }, idtype=idtype, device=F.ctx())
+    g.edges['r3'].data['w'] = F.tensor([0.1, 0.2, 0.3])
+    transform = dgl.GCNNorm()
+    new_g = transform(g)
+    assert 'w' not in new_g.edges[('A', 'r2', 'B')].data
+    assert F.allclose(new_g.edges[('A', 'r1', 'A')].data['w'],
+                      F.tensor([1./2, 1./math.sqrt(2), 0.]))
+    assert F.allclose(new_g.edges[('B', 'r3', 'B')].data['w'], F.tensor([1./3, 2./3, 0.]))
+
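The asserted weights are consistent with symmetric GCN normalization, w_uv / sqrt(deg(u) * deg(v)), computed from (weighted) in-degrees with zero degrees mapped to zero; a quick check of the 'r3' values under that assumption:

import math

# Weighted in-degrees on etype 'r3' with w = [0.1, 0.2, 0.3]:
# deg(0) = 0.1 + 0.2 = 0.3, deg(1) = 0.3, deg(2) = 0.0
for w, du, dv in [(0.1, 0.3, 0.3),   # edge 0 -> 0
                  (0.2, 0.3, 0.3),   # edge 1 -> 0
                  (0.3, 0.0, 0.3)]:  # edge 2 -> 1
    print(w / math.sqrt(du * dv) if du * dv > 0 else 0.)
# Prints 1/3, 2/3, 0 -- matching the expected tensor in the test.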
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_ppr(idtype):
+    g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx())
+    g.ndata['h'] = F.randn((6, 2))
+    transform = dgl.PPR(avg_degree=2)
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.num_nodes() == g.num_nodes()
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (0, 2), (0, 4), (1, 1), (1, 3), (1, 5), (2, 2),
+                    (2, 3), (2, 4), (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), (5, 5)}
+    assert F.allclose(g.ndata['h'], new_g.ndata['h'])
+    assert 'w' in new_g.edata
+
+    # Prior edge weights
+    g.edata['w'] = F.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
+    new_g = transform(g)
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (1, 1), (1, 3), (2, 2), (2, 3), (2, 4),
+                    (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), (5, 5)}
+
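PPR rewires the graph using personalized PageRank scores, which is the reason the backends gain an inverse op: the diffusion has the closed form S = alpha * (I - (1 - alpha) * T)^-1. A minimal dense sketch under that standard formulation (alpha and the symmetric normalization are assumptions, not necessarily the module's exact defaults):

import torch

def ppr_dense(adj, alpha=0.15):
    # Symmetric normalization T = D^-1/2 A D^-1/2, zeroing 0-degree rows.
    d_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
    t = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    # Closed-form PPR diffusion: S = alpha * (I - (1 - alpha) * T)^-1.
    eye = torch.eye(adj.shape[0])
    return alpha * torch.inverse(eye - (1 - alpha) * t)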
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_heat_kernel(idtype):
+    # Case1: directed graph
+    g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx())
+    g.ndata['h'] = F.randn((6, 2))
+    transform = dgl.HeatKernel(avg_degree=1)
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.num_nodes() == g.num_nodes()
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 2), (0, 4), (1, 3), (1, 5), (2, 3), (2, 4), (3, 5), (4, 5)}
+    assert F.allclose(g.ndata['h'], new_g.ndata['h'])
+    assert 'w' in new_g.edata
+
+    # Case2: weighted undirected graph
+    g = dgl.graph(([0, 1, 2, 3], [1, 0, 3, 2]), idtype=idtype, device=F.ctx())
+    g.edata['w'] = F.tensor([0.1, 0.2, 0.3, 0.4])
+    new_g = transform(g)
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (1, 1), (2, 2), (3, 3)}
+
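HeatKernel replaces PPR's matrix inverse with the matrix exponential S = exp(t (T - I)). A truncated Taylor-series sketch of that diffusion (the diffusion time t and truncation order k here are illustrative choices):

import math
import torch

def heat_kernel_dense(t_mat, t=2.0, k=16):
    # S = exp(t (T - I)) = exp(-t) * sum_{i=0..k} (t T)^i / i!, truncated at k.
    term = torch.eye(t_mat.shape[0])  # i = 0 term
    s = term.clone()
    for i in range(1, k + 1):
        term = term @ t_mat * (t / i)
        s = s + term
    return s * math.exp(-t)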
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_gdc(idtype):
+    transform = dgl.GDC([0.1, 0.2, 0.1], avg_degree=1)
+    g = dgl.graph(([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx())
+    g.ndata['h'] = F.randn((6, 2))
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.num_nodes() == g.num_nodes()
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (0, 2), (0, 4), (1, 1), (1, 3), (1, 5), (2, 2), (2, 3),
+                    (2, 4), (3, 3), (3, 5), (4, 3), (4, 4), (4, 5), (5, 5)}
+    assert F.allclose(g.ndata['h'], new_g.ndata['h'])
+    assert 'w' in new_g.edata
+
+    # Prior edge weights
+    g.edata['w'] = F.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
+    new_g = transform(g)
+    src, dst = new_g.edges()
+    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
+    assert eset == {(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (4, 4), (5, 5)}
+
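GDC generalizes both diffusions above: given coefficients theta_k it forms S = sum_k theta_k * T^k, with PPR and the heat kernel recovered by particular theta sequences. A dense sketch under that assumption (the coefficient list mirrors the [0.1, 0.2, 0.1] passed to dgl.GDC in the test):

import torch

def gdc_dense(t_mat, coefs=(0.1, 0.2, 0.1)):
    # Generalized graph diffusion: S = sum_k coefs[k] * T^k, with T^0 = I.
    n = t_mat.shape[0]
    power = torch.eye(n)
    s = torch.zeros(n, n)
    for theta in coefs:
        s = s + theta * power
        power = power @ t_mat
    return s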
+@parametrize_dtype
+def test_module_node_shuffle(idtype):
+    transform = dgl.NodeShuffle()
+    g = dgl.heterograph({
+        ('A', 'r', 'B'): ([0, 1], [1, 2]),
+    }, idtype=idtype, device=F.ctx())
+    new_g = transform(g)
+
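NodeShuffle keeps the topology and randomly reassigns node features across nodes of each type. A one-line sketch of the core idea, as a hypothetical helper operating on a raw feature matrix:

import torch

def shuffle_rows(feat):
    # Permute rows so each node receives another node's feature vector.
    return feat[torch.randperm(feat.shape[0])]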
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_drop_node(idtype):
+    transform = dgl.DropNode()
+    g = dgl.heterograph({
+        ('A', 'r', 'B'): ([0, 1], [1, 2]),
+    }, idtype=idtype, device=F.ctx())
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.ntypes == g.ntypes
+    assert new_g.canonical_etypes == g.canonical_etypes
+
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+@parametrize_dtype
+def test_module_drop_edge(idtype):
+    transform = dgl.DropEdge()
+    g = dgl.heterograph({
+        ('A', 'r1', 'B'): ([0, 1], [1, 2]),
+        ('C', 'r2', 'C'): ([3, 4, 5], [6, 7, 8])
+    }, idtype=idtype, device=F.ctx())
+    new_g = transform(g)
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.ntypes == g.ntypes
+    assert new_g.canonical_etypes == g.canonical_etypes
+
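DropNode and DropEdge are the standard Bernoulli-mask augmentations: each node (or edge) is kept independently with probability 1 - p. A minimal sketch of the edge variant, as a hypothetical helper (DropNode works analogously with a mask over nodes):

import torch

def drop_edges(src, dst, p=0.5):
    # Keep each edge independently with probability 1 - p.
    keep = torch.rand(src.shape[0]) >= p
    return src[keep], dst[keep]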
+@parametrize_dtype
+def test_module_add_edge(idtype):
+    transform = dgl.AddEdge()
+    g = dgl.heterograph({
+        ('A', 'r1', 'B'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5]),
+        ('C', 'r2', 'C'): ([0, 1, 2, 3, 4], [1, 2, 3, 4, 5])
+    }, idtype=idtype, device=F.ctx())
+    new_g = transform(g)
+    assert new_g.num_edges(('A', 'r1', 'B')) == 6
+    assert new_g.num_edges(('C', 'r2', 'C')) == 6
+    assert new_g.idtype == g.idtype
+    assert new_g.device == g.device
+    assert new_g.ntypes == g.ntypes
+    assert new_g.canonical_etypes == g.canonical_etypes
+
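AddEdge appears to add a fixed ratio of new edges with randomly sampled endpoints, consistent with the 5 -> 6 edge counts asserted above (a 0.2 ratio). A sketch of that idea for a homogeneous graph, as a hypothetical helper:

import torch

def add_random_edges(num_nodes, src, dst, ratio=0.2):
    # Add ratio * num_edges new edges with uniform random endpoints;
    # 5 edges * 0.2 = 1 new edge, matching the test's 5 -> 6 counts.
    n_add = int(src.shape[0] * ratio)
    new_src = torch.randint(0, num_nodes, (n_add,))
    new_dst = torch.randint(0, num_nodes, (n_add,))
    return torch.cat([src, new_src]), torch.cat([dst, new_dst])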
 
 if __name__ == '__main__':
     test_partition_with_halo()