Unverified commit 4aff59f7 authored by Quan (Andy) Gan, committed by GitHub

[Sparse] Rename existing DGL sparse module. (#5066)

* rename

* next time i should use lintrunner
parent 0159c3c1
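The commit renames DGL's internal sparse-kernel module from `dgl.sparse` to `dgl._sparse_ops` and updates every call site: the module itself, the MXNet/PyTorch/TensorFlow backend wrappers, the Libra partitioning example, and the tests. For downstream code the effect is an import-path change only; a hedged sketch of the before/after, using function names taken from the diffs below:

    # Before this commit
    from dgl.sparse import _csrmm, _csrsum, _csrmask

    # After this commit
    from dgl._sparse_ops import _csrmm, _csrsum, _csrmask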
@@ -2,8 +2,7 @@
 # pylint: disable= invalid-name
 from __future__ import absolute_import

-from . import backend as F
-from . import ndarray as nd
+from . import backend as F, ndarray as nd
 from ._ffi.function import _init_api
 from .base import DGLError
@@ -1036,4 +1035,4 @@ def libra2dgl_set_lr(gdt_key, gdt_value, lrtensor, nc, Nn):
     )

-_init_api("dgl.sparse")
+_init_api("dgl.sparse", __name__)
@@ -2,9 +2,7 @@ import mxnet as mx
 import numpy as np
 from mxnet import nd

-from ...base import ALL, dgl_warning, is_all
-from ...heterograph_index import create_unitgraph_from_csr
-from ...sparse import (
+from ..._sparse_ops import (
     _bwd_segment_cmp,
     _csrmask,
     _csrmm,
@@ -14,6 +12,9 @@ from ...sparse import (
     _scatter_add,
     _segment_reduce,
 )
+from ...base import ALL, dgl_warning, is_all
+from ...heterograph_index import create_unitgraph_from_csr
 from .tensor import (
     asnumpy,
     context,
......
 import torch as th

-from ...base import ALL, is_all
-from ...heterograph_index import create_unitgraph_from_csr
-from ...sparse import (
+from ..._sparse_ops import (
     _bwd_segment_cmp,
     _csrmask,
     _csrmm,
@@ -22,6 +20,9 @@ from ...sparse import (
     _update_grad_minmax_hetero,
 )
+from ...base import ALL, is_all
+from ...heterograph_index import create_unitgraph_from_csr
 __all__ = [
     "gspmm",
     "gsddmm",
@@ -38,6 +39,7 @@ __all__ = [
     "segment_mm",
 ]

+
 def _reduce_grad(grad, shape):
     """Reduce gradient on the broadcast dimension

     If there is broadcast in forward pass, gradients need to be reduced on
@@ -125,23 +127,32 @@ def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
         return True
     return False

-class empty_context():
+class empty_context:
     """Empty context that does nothing"""

     def __init__(self, *args, **kargs):
         return

     def __enter__(self, *args, **kargs):
         return self

     def __exit__(self, *args, **kargs):
         return

+
 # This is to avoid warnings in cpu-only dgl. We don't enable autocast for CPU ops
 autocast = th.cuda.amp.autocast if th.cuda.is_available() else empty_context

+
 def _cast_if_autocast_enabled(*args):
     if not th.is_autocast_enabled() or not th.cuda.is_available():
         return args
     else:
-        return th.cuda.amp.autocast_mode._cast(args, th.get_autocast_gpu_dtype())
+        return th.cuda.amp.autocast_mode._cast(
+            args, th.get_autocast_gpu_dtype()
+        )

 class GSpMM(th.autograd.Function):
     @staticmethod
@@ -1023,7 +1034,9 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
     if op == "div":
         op = "mul"
         rhs_data = 1.0 / rhs_data
-    args = _cast_if_autocast_enabled(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
+    args = _cast_if_autocast_enabled(
+        gidx, op, lhs_data, rhs_data, lhs_target, rhs_target
+    )
     with autocast(enabled=False):
         return GSDDMM.apply(*args)
@@ -1052,7 +1065,9 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
     if op in ["add", "mul"]:
         lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
-    args = _cast_if_autocast_enabled(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple)
+    args = _cast_if_autocast_enabled(
+        g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple
+    )
     with autocast(enabled=False):
         return GSpMM_hetero.apply(*args)
@@ -1083,7 +1098,9 @@ def gsddmm_hetero(
     if op in ["add", "mul"]:
         lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
-    args = _cast_if_autocast_enabled(g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple)
+    args = _cast_if_autocast_enabled(
+        g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple
+    )
     with autocast(enabled=False):
         return GSDDMM_hetero.apply(*args)
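The `_cast_if_autocast_enabled` / `with autocast(enabled=False)` pairing above is the standard way to make a custom `torch.autograd.Function` AMP-safe: cast the inputs once at the current autocast dtype, then disable autocast so the kernels inside the Function see consistent dtypes instead of being re-cast op by op. A minimal self-contained sketch of the same pattern; `ScaledMatmul` is a toy illustration, not part of DGL:

    class ScaledMatmul(th.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            ctx.save_for_backward(x, y)
            return (x @ y) * 0.5

        @staticmethod
        def backward(ctx, grad):
            x, y = ctx.saved_tensors
            # d(0.5 * x @ y)/dx = 0.5 * grad @ y^T, and likewise for y.
            return (grad @ y.t()) * 0.5, (x.t() @ grad) * 0.5

    def scaled_matmul(x, y):
        # Cast once while autocast is active, then run with autocast
        # off so ops inside forward/backward are not re-cast.
        args = _cast_if_autocast_enabled(x, y)
        with autocast(enabled=False):
            return ScaledMatmul.apply(*args)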
......
 import numpy as np
 import tensorflow as tf

-from ...base import ALL, is_all
-from ...heterograph_index import create_unitgraph_from_csr
-from ...sparse import (
+from ..._sparse_ops import (
     _bwd_segment_cmp,
     _csrmask,
     _csrmm,
@@ -13,6 +11,9 @@ from ...sparse import (
     _scatter_add,
     _segment_reduce,
 )
+from ...base import ALL, is_all
+from ...heterograph_index import create_unitgraph_from_csr
 from .tensor import asnumpy, context, copy_to, tensor, zerocopy_from_numpy

 __all__ = [
......
@@ -25,14 +25,14 @@ import time

 import torch as th
 from dgl import DGLGraph
-from dgl.base import DGLError
-from dgl.data.utils import save_graphs, save_tensors
-from dgl.sparse import (
+from dgl._sparse_ops import (
     libra2dgl_build_adjlist,
     libra2dgl_build_dict,
     libra2dgl_set_lr,
     libra_vertex_cut,
 )
+from dgl.base import DGLError
+from dgl.data.utils import save_graphs, save_tensors


 def libra_partition(num_community, G, resultdir):
......
+import backend as F
+import dgl
 import numpy as np
-import scipy.sparse as ssp
 import pytest
-import dgl
+import scipy.sparse as ssp
 from test_utils import parametrize_idtype
-import backend as F

-if F.backend_name == 'pytorch':
+if F.backend_name == "pytorch":
     import torch
     torch.backends.cuda.matmul.allow_tf32 = False

-def _random_simple_graph(idtype, dtype, ctx, M, N, max_nnz, srctype, dsttype, etype):
+def _random_simple_graph(
+    idtype, dtype, ctx, M, N, max_nnz, srctype, dsttype, etype
+):
     src = np.random.randint(0, M, (max_nnz,))
     dst = np.random.randint(0, N, (max_nnz,))
     val = np.random.randn(max_nnz)
@@ -24,33 +28,55 @@ def _random_simple_graph(idtype, dtype, ctx, M, N, max_nnz, srctype, dsttype, et
     a = ssp.csr_matrix((val, (row, col)), shape=(M, N))
     A = dgl.heterograph(
-        {(srctype, etype, dsttype): (
-            F.copy_to(F.tensor(row, dtype=idtype), ctx),
-            F.copy_to(F.tensor(col, dtype=idtype), ctx))},
-        num_nodes_dict={srctype: a.shape[0], dsttype: a.shape[1]})
-    A.edata['w'] = F.copy_to(F.tensor(val, dtype=dtype), ctx)
+        {
+            (srctype, etype, dsttype): (
+                F.copy_to(F.tensor(row, dtype=idtype), ctx),
+                F.copy_to(F.tensor(col, dtype=idtype), ctx),
+            )
+        },
+        num_nodes_dict={srctype: a.shape[0], dsttype: a.shape[1]},
+    )
+    A.edata["w"] = F.copy_to(F.tensor(val, dtype=dtype), ctx)
     return a, A

 @parametrize_idtype
-@pytest.mark.parametrize('dtype', [F.float32, F.float64])
+@pytest.mark.parametrize("dtype", [F.float32, F.float64])
 def test_csrmm(idtype, dtype):
-    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB')
-    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 600, 700, 9000, 'B', 'C', 'BC')
-    C, C_weights = dgl.sparse._csrmm(A._graph, A.edata['w'], B._graph, B.edata['w'], 2)
-    C_adj = C.adjacency_matrix_scipy(0, False, 'csr')
+    a, A = _random_simple_graph(
+        idtype, dtype, F.ctx(), 500, 600, 9000, "A", "B", "AB"
+    )
+    b, B = _random_simple_graph(
+        idtype, dtype, F.ctx(), 600, 700, 9000, "B", "C", "BC"
+    )
+    C, C_weights = dgl._sparse_ops._csrmm(
+        A._graph, A.edata["w"], B._graph, B.edata["w"], 2
+    )
+    C_adj = C.adjacency_matrix_scipy(0, False, "csr")
     C_adj.data = F.asnumpy(C_weights)
     C_adj = F.tensor(C_adj.todense(), dtype=dtype)
     c = F.tensor((a * b).todense(), dtype=dtype)
     assert F.allclose(C_adj, c)

 @parametrize_idtype
-@pytest.mark.parametrize('dtype', [F.float32, F.float64])
-@pytest.mark.parametrize('num_vtypes', [1, 2])
+@pytest.mark.parametrize("dtype", [F.float32, F.float64])
+@pytest.mark.parametrize("num_vtypes", [1, 2])
 def test_csrmm_backward(idtype, dtype, num_vtypes):
-    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
-    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 4, 3, 6, 'B', 'A' if num_vtypes == 1 else 'C', 'BA')
-    A_row, A_col = A.edges(order='eid')
-    B_row, B_col = B.edges(order='eid')
+    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, "A", "B", "AB")
+    b, B = _random_simple_graph(
+        idtype,
+        dtype,
+        F.ctx(),
+        4,
+        3,
+        6,
+        "B",
+        "A" if num_vtypes == 1 else "C",
+        "BA",
+    )
+    A_row, A_col = A.edges(order="eid")
+    B_row, B_col = B.edges(order="eid")
     A_row = F.asnumpy(A_row)
     A_col = F.asnumpy(A_col)
     B_row = F.asnumpy(B_row)
@@ -58,49 +84,57 @@ def test_csrmm_backward(idtype, dtype, num_vtypes):
     a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))
     b_dense = F.attach_grad(F.tensor(b.todense(), dtype=dtype))
-    A.edata['w'] = F.attach_grad(A.edata['w'])
-    B.edata['w'] = F.attach_grad(B.edata['w'])
+    A.edata["w"] = F.attach_grad(A.edata["w"])
+    B.edata["w"] = F.attach_grad(B.edata["w"])
     with F.record_grad():
-        C = dgl.adj_product_graph(A, B, 'w')
+        C = dgl.adj_product_graph(A, B, "w")
         assert len(C.ntypes) == num_vtypes
         assert len(C.etypes) == 1
         C_dense = np.zeros((3, 3))
-        C_row, C_col = C.edges(order='eid')
+        C_row, C_col = C.edges(order="eid")
         C_row = F.asnumpy(C_row)
         C_col = F.asnumpy(C_col)
-        C_dense[C_row, C_col] = F.asnumpy(C.edata['w'])
+        C_dense[C_row, C_col] = F.asnumpy(C.edata["w"])
         c_dense = F.matmul(a_dense, b_dense)
         assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4)

-        F.backward(F.reduce_sum(C.edata['w']) + F.reduce_sum(c_dense))
+        F.backward(F.reduce_sum(C.edata["w"]) + F.reduce_sum(c_dense))
     a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]
     b_dense_grad = F.asnumpy(F.grad(b_dense))[B_row, B_col]
-    A_spspmm_grad = F.asnumpy(F.grad(A.edata['w']))
-    B_spspmm_grad = F.asnumpy(F.grad(B.edata['w']))
+    A_spspmm_grad = F.asnumpy(F.grad(A.edata["w"]))
+    B_spspmm_grad = F.asnumpy(F.grad(B.edata["w"]))
     assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)
     assert np.allclose(b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4)

 @parametrize_idtype
-@pytest.mark.parametrize('dtype', [F.float32, F.float64])
+@pytest.mark.parametrize("dtype", [F.float32, F.float64])
 def test_csrsum(idtype, dtype):
-    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB')
-    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB')
-    C, C_weights = dgl.sparse._csrsum([A._graph, B._graph], [A.edata['w'], B.edata['w']])
-    C_adj = C.adjacency_matrix_scipy(0, False, 'csr')
+    a, A = _random_simple_graph(
+        idtype, dtype, F.ctx(), 500, 600, 9000, "A", "B", "AB"
+    )
+    b, B = _random_simple_graph(
+        idtype, dtype, F.ctx(), 500, 600, 9000, "A", "B", "AB"
+    )
+    C, C_weights = dgl._sparse_ops._csrsum(
+        [A._graph, B._graph], [A.edata["w"], B.edata["w"]]
+    )
+    C_adj = C.adjacency_matrix_scipy(0, False, "csr")
     C_adj.data = F.asnumpy(C_weights)
     C_adj = F.tensor(C_adj.todense(), dtype=dtype)
     c = F.tensor((a + b).todense(), dtype=dtype)
     assert F.allclose(C_adj, c)

 @parametrize_idtype
-@pytest.mark.parametrize('dtype', [F.float32, F.float64])
-@pytest.mark.parametrize('nelems', [1, 2])
+@pytest.mark.parametrize("dtype", [F.float32, F.float64])
+@pytest.mark.parametrize("nelems", [1, 2])
 def test_csrsum_backward(idtype, dtype, nelems):
-    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
-    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
-    A_row, A_col = A.edges(order='eid')
-    B_row, B_col = B.edges(order='eid')
+    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, "A", "B", "AB")
+    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, "A", "B", "AB")
+    A_row, A_col = A.edges(order="eid")
+    B_row, B_col = B.edges(order="eid")
     A_row = F.asnumpy(A_row)
     A_col = F.asnumpy(A_col)
     B_row = F.asnumpy(B_row)
@@ -108,80 +142,97 @@ def test_csrsum_backward(idtype, dtype, nelems):
     a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))
     b_dense = F.attach_grad(F.tensor(b.todense(), dtype=dtype))
-    A.edata['w'] = F.attach_grad(A.edata['w'])
-    B.edata['w'] = F.attach_grad(B.edata['w'])
+    A.edata["w"] = F.attach_grad(A.edata["w"])
+    B.edata["w"] = F.attach_grad(B.edata["w"])
     with F.record_grad():
         if nelems == 2:
             # Test for two element case
-            C = dgl.adj_sum_graph([A, B], 'w')
+            C = dgl.adj_sum_graph([A, B], "w")
             assert C.canonical_etypes == A.canonical_etypes
             C_dense = np.zeros((3, 4))
-            C_row, C_col = C.edges(order='eid')
+            C_row, C_col = C.edges(order="eid")
             C_row = F.asnumpy(C_row)
             C_col = F.asnumpy(C_col)
-            C_dense[C_row, C_col] = F.asnumpy(C.edata['w'])
+            C_dense[C_row, C_col] = F.asnumpy(C.edata["w"])
             c_dense = a_dense + b_dense
-            assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4)
+            assert np.allclose(
+                C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4
+            )

-            F.backward(F.reduce_sum(C.edata['w']) + F.reduce_sum(c_dense))
+            F.backward(F.reduce_sum(C.edata["w"]) + F.reduce_sum(c_dense))
             a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]
             b_dense_grad = F.asnumpy(F.grad(b_dense))[B_row, B_col]
-            A_spspmm_grad = F.asnumpy(F.grad(A.edata['w']))
-            B_spspmm_grad = F.asnumpy(F.grad(B.edata['w']))
-            assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)
-            assert np.allclose(b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4)
+            A_spspmm_grad = F.asnumpy(F.grad(A.edata["w"]))
+            B_spspmm_grad = F.asnumpy(F.grad(B.edata["w"]))
+            assert np.allclose(
+                a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4
+            )
+            assert np.allclose(
+                b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4
+            )
         elif nelems == 1:
             # Test for single element case
-            C = dgl.adj_sum_graph([A], 'w')
+            C = dgl.adj_sum_graph([A], "w")
             assert C.canonical_etypes == A.canonical_etypes
             C_dense = np.zeros((3, 4))
-            C_row, C_col = C.edges(order='eid')
+            C_row, C_col = C.edges(order="eid")
             C_row = F.asnumpy(C_row)
             C_col = F.asnumpy(C_col)
-            C_dense[C_row, C_col] = F.asnumpy(C.edata['w'])
+            C_dense[C_row, C_col] = F.asnumpy(C.edata["w"])
             c_dense = a_dense
-            assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4)
+            assert np.allclose(
+                C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4
+            )

-            F.backward(F.reduce_sum(C.edata['w']) + F.reduce_sum(c_dense))
+            F.backward(F.reduce_sum(C.edata["w"]) + F.reduce_sum(c_dense))
             a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]
-            A_spspmm_grad = F.asnumpy(F.grad(A.edata['w']))
-            assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)
+            A_spspmm_grad = F.asnumpy(F.grad(A.edata["w"]))
+            assert np.allclose(
+                a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4
+            )

 @parametrize_idtype
-@pytest.mark.parametrize('dtype', [F.float32, F.float64])
-@pytest.mark.parametrize('A_nnz', [9000, 0])
-@pytest.mark.parametrize('B_nnz', [9000, 0])
+@pytest.mark.parametrize("dtype", [F.float32, F.float64])
+@pytest.mark.parametrize("A_nnz", [9000, 0])
+@pytest.mark.parametrize("B_nnz", [9000, 0])
 def test_csrmask(idtype, dtype, A_nnz, B_nnz):
-    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, A_nnz, 'A', 'B', 'AB')
-    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, B_nnz, 'A', 'B', 'AB')
-    C = dgl.sparse._csrmask(A._graph, A.edata['w'], B._graph)
-    B_row, B_col = B.edges(order='eid')
+    a, A = _random_simple_graph(
+        idtype, dtype, F.ctx(), 500, 600, A_nnz, "A", "B", "AB"
+    )
+    b, B = _random_simple_graph(
+        idtype, dtype, F.ctx(), 500, 600, B_nnz, "A", "B", "AB"
+    )
+    C = dgl._sparse_ops._csrmask(A._graph, A.edata["w"], B._graph)
+    B_row, B_col = B.edges(order="eid")
     B_row = F.asnumpy(B_row)
     B_col = F.asnumpy(B_col)
     c = F.tensor(a.todense()[B_row, B_col], dtype)
     assert F.allclose(C, c)

 @parametrize_idtype
-@pytest.mark.parametrize('dtype', [F.float32, F.float64])
+@pytest.mark.parametrize("dtype", [F.float32, F.float64])
 def test_csrmask_backward(idtype, dtype):
-    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
-    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
-    A_row, A_col = A.edges(order='eid')
-    B_row, B_col = B.edges(order='eid')
+    a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, "A", "B", "AB")
+    b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, "A", "B", "AB")
+    A_row, A_col = A.edges(order="eid")
+    B_row, B_col = B.edges(order="eid")
     A_row = F.asnumpy(A_row)
     A_col = F.asnumpy(A_col)
     B_row = F.asnumpy(B_row)
     B_col = F.asnumpy(B_col)
     a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))
-    A.edata['w'] = F.attach_grad(A.edata['w'])
+    A.edata["w"] = F.attach_grad(A.edata["w"])
     with F.record_grad():
         # Test for two element case
-        C1 = F.csrmask(A._graph, A.edata['w'], B._graph)
-        if dgl.backend.backend_name == 'tensorflow':
+        C1 = F.csrmask(A._graph, A.edata["w"], B._graph)
+        if dgl.backend.backend_name == "tensorflow":
             import tensorflow as tf
+
             C2 = tf.gather_nd(a_dense, tf.stack([B_row, B_col], 1))
         else:
             C2 = a_dense[B_row, B_col]
@@ -189,11 +240,11 @@ def test_csrmask_backward(idtype, dtype):
         F.backward(F.reduce_sum(C1) + F.reduce_sum(C2))
     a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]
-    A_spspmm_grad = F.asnumpy(F.grad(A.edata['w']))
+    A_spspmm_grad = F.asnumpy(F.grad(A.edata["w"]))
     assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)

-if __name__ == '__main__':
+if __name__ == "__main__":
     test_csrmm(F.int32, F.float32)
     test_csrmm(F.int64, F.float32)
     test_csrsum(F.int32, F.float32)
......