Unverified Commit 657c220d authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Python interface for adjacency matrix summation and multiplication (#2893)

* test commit

* fixes

* oops

* add docs

* lint

* why does it say I have a trailing whitespace

* oh ok

* fixes

* why there's an invalid argument error

* address comments

* fix

* address comments
parent 29fec7d5
...@@ -74,6 +74,8 @@ Operators for generating new graphs by manipulating the structure of the existin ...@@ -74,6 +74,8 @@ Operators for generating new graphs by manipulating the structure of the existin
line_graph line_graph
khop_graph khop_graph
metapath_reachable_graph metapath_reachable_graph
adj_product_graph
adj_sum_graph
.. _api-batch: .. _api-batch:
......
...@@ -1574,6 +1574,91 @@ def scatter_add(x, idx, m): ...@@ -1574,6 +1574,91 @@ def scatter_add(x, idx, m):
""" """
pass pass
def csrmm(A, A_weights, B, B_weights, num_vtypes):
"""Compute weighted adjacency matrix multiplication.
Notes
-----
Both A and B must allow creation of CSR representations, and must be simple graphs
(i.e. having at most one edge between two nodes).
The output unit graph has no format restriction.
Parameters
----------
A : HeteroGraphIndex
The unit graph as left operand.
A_weights : Tensor
The edge weights of A. Must be a 1D vector.
B : HeteroGraphIndex
The unit graph as right operand.
B_weights : Tensor
The edge weights of B. Must be a 1D vector.
num_vtypes : int
The number of node types of the output graph. Must be either 1 or 2.
Returns
-------
HeteroGraphIndex
The output unit graph.
Tensor
The output edge weights.
"""
pass
def csrsum(gidxs, weights):
"""Compute weighted adjacency matrix summation.
Notes
-----
All unit graphs must allow creation of CSR representations, and must be simple graphs
(i.e. having at most one edge between two nodes).
The output unit graph has no format restriction.
Parameters
----------
gidxs : list[HeteroGraphIndex]
The unit graphs.
weights : list[Tensor]
The edge weights of each graph. Must be 1D vectors.
Returns
-------
HeteroGraphIndex
The output unit graph.
Tensor
The output edge weights.
"""
pass
def csrmask(A, A_weights, B):
"""Retrieve the values in the weighted adjacency matrix of graph :attr:`A` at the
non-zero positions of graph :attr:`B`'s adjacency matrix.
In scipy, this is equivalent to ``A[B != 0]``.
Notes
-----
Both A and B must allow creation of CSR representations, and must be simple graphs
(i.e. having at most one edge between two nodes).
Parameters
----------
A : HeteroGraphIndex
The unit graph as left operand.
A_weights : Tensor
The edge weights of A. Must be a 1D vector.
B : HeteroGraphIndex
The unit graph as right operand.
Returns
-------
Tensor
The output tensor.
"""
pass
############################################################################### ###############################################################################
# Other interfaces # Other interfaces
......
...@@ -2,10 +2,13 @@ import mxnet as mx ...@@ -2,10 +2,13 @@ import mxnet as mx
import numpy as np import numpy as np
from mxnet import nd from mxnet import nd
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...base import dgl_warning, is_all, ALL from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
from ...heterograph_index import create_unitgraph_from_csr
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add'] __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']
def _scatter_nd(index, src, n_rows): def _scatter_nd(index, src, n_rows):
...@@ -379,3 +382,88 @@ class ScatterAdd(mx.autograd.Function): ...@@ -379,3 +382,88 @@ class ScatterAdd(mx.autograd.Function):
def scatter_add(x, idx, m): def scatter_add(x, idx, m):
scatter_add_op = ScatterAdd(idx, m) scatter_add_op = ScatterAdd(idx, m)
return scatter_add_op(x) return scatter_add_op(x)
class CSRMM(mx.autograd.Function):
def __init__(self, gidxA, gidxB, num_vtypes):
super().__init__()
self.gidxA = gidxA
self.gidxB = gidxB
self.num_vtypes = num_vtypes
def forward(self, A_weights, B_weights):
gidxC, C_weights = _csrmm(self.gidxA, A_weights, self.gidxB, B_weights, self.num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr')
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC
self.save_for_backward(A_weights, B_weights)
nrows = nd.array([nrows], dtype='int64')
ncols = nd.array([ncols], dtype='int64')
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful.
gidxC = self.backward_cache
A_weights, B_weights = self.saved_tensors
dgidxA, dA_weights = _csrmm(
gidxC, dC_weights, self.gidxB.reverse(), B_weights, self.gidxA.number_of_ntypes())
dgidxB, dB_weights = _csrmm(
self.gidxA.reverse(), A_weights, gidxC, dC_weights, self.gidxB.number_of_ntypes())
dA_weights = _csrmask(dgidxA, dA_weights, self.gidxA)
dB_weights = _csrmask(dgidxB, dB_weights, self.gidxB)
return dA_weights, dB_weights
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
op = CSRMM(gidxA, gidxB, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(A_weights, B_weights)
gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
return gidxC, C_weights
class CSRSum(mx.autograd.Function):
def __init__(self, gidxs):
super().__init__()
self.gidxs = gidxs
def forward(self, *weights):
gidxC, C_weights = _csrsum(self.gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(
0, True, 'csr')
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
self.backward_cache = gidxC
nrows = nd.array([nrows], dtype='int64')
ncols = nd.array([ncols], dtype='int64')
return nrows, ncols, C_indptr, C_indices, C_eids, C_weights
def backward(self, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful.
gidxC = self.backward_cache
return tuple(csrmask(gidxC, dC_weights, gidx) for gidx in self.gidxs)
def csrsum(gidxs, weights):
op = CSRSum(gidxs)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = op(*weights)
num_vtypes = gidxs[0].number_of_ntypes()
gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.asscalar(), ncols.asscalar(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
return gidxC, C_weights
class CSRMask(mx.autograd.Function):
def __init__(self, gidxA, gidxB):
super().__init__()
self.gidxA = gidxA
self.gidxB = gidxB
def forward(self, A_weights):
return _csrmask(self.gidxA, A_weights, self.gidxB)
def backward(self, dB_weights):
return _csrmask(self.gidxB, dB_weights, self.gidxA)
def csrmask(gidxA, A_weights, gidxB):
op = CSRMask(gidxA, gidxB)
return op(A_weights)
...@@ -2,6 +2,8 @@ import torch as th ...@@ -2,6 +2,8 @@ import torch as th
from distutils.version import LooseVersion from distutils.version import LooseVersion
from ...base import is_all, ALL from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...heterograph_index import create_unitgraph_from_csr
if LooseVersion(th.__version__) >= LooseVersion("1.6.0"): if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import custom_fwd, custom_bwd from torch.cuda.amp import custom_fwd, custom_bwd
...@@ -24,7 +26,8 @@ else: ...@@ -24,7 +26,8 @@ else:
return bwd(*args, **kwargs) return bwd(*args, **kwargs)
return decorate_bwd return decorate_bwd
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add'] __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']
def _reduce_grad(grad, shape): def _reduce_grad(grad, shape):
...@@ -303,6 +306,62 @@ class ScatterAdd(th.autograd.Function): ...@@ -303,6 +306,62 @@ class ScatterAdd(th.autograd.Function):
return dy[idx], None, None return dy[idx], None, None
class CSRMM(th.autograd.Function):
@staticmethod
def forward(ctx, gidxA, A_weights, gidxB, B_weights, num_vtypes):
gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr')
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
ctx.backward_cache = gidxA, gidxB, gidxC
ctx.save_for_backward(A_weights, B_weights)
return th.tensor(nrows), th.tensor(ncols), C_indptr, C_indices, C_eids, C_weights
@staticmethod
def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful.
gidxA, gidxB, gidxC = ctx.backward_cache
A_weights, B_weights = ctx.saved_tensors
dgidxA, dA_weights = csrmm(
gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes())
dgidxB, dB_weights = csrmm(
gidxA.reverse(), A_weights, gidxC, dC_weights, gidxB.number_of_ntypes())
dA_weights = csrmask(dgidxA, dA_weights, gidxA)
dB_weights = csrmask(dgidxB, dB_weights, gidxB)
return None, dA_weights, None, dB_weights, None
class CSRSum(th.autograd.Function):
@staticmethod
def forward(ctx, gidxs, *weights):
# PyTorch tensors must be explicit arguments of the forward function
gidxC, C_weights = _csrsum(gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(
0, True, 'csr')
# Note: the returned C_indptr, C_indices and C_eids tensors MUST be the same
# as the underlying tensors of the created graph gidxC.
ctx.backward_cache = gidxs, gidxC
return th.tensor(nrows), th.tensor(ncols), C_indptr, C_indices, C_eids, C_weights
@staticmethod
def backward(ctx, dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful.
gidxs, gidxC = ctx.backward_cache
return (None,) + tuple(csrmask(gidxC, dC_weights, gidx) for gidx in gidxs)
class CSRMask(th.autograd.Function):
@staticmethod
def forward(ctx, gidxA, A_weights, gidxB):
ctx.backward_cache = gidxA, gidxB
return _csrmask(gidxA, A_weights, gidxB)
@staticmethod
def backward(ctx, dB_weights):
gidxA, gidxB = ctx.backward_cache
return None, csrmask(gidxB, dB_weights, gidxA), None
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data) return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
...@@ -320,3 +379,21 @@ def segment_reduce(op, x, offsets): ...@@ -320,3 +379,21 @@ def segment_reduce(op, x, offsets):
def scatter_add(x, idx, m): def scatter_add(x, idx, m):
return ScatterAdd.apply(x, idx, m) return ScatterAdd.apply(x, idx, m)
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = \
CSRMM.apply(gidxA, A_weights, gidxB, B_weights, num_vtypes)
gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.item(), ncols.item(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
return gidxC, C_weights
def csrsum(gidxs, weights):
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = CSRSum.apply(gidxs, *weights)
gidxC = create_unitgraph_from_csr(
gidxs[0].number_of_ntypes(), nrows.item(), ncols.item(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
return gidxC, C_weights
def csrmask(gidxA, A_weights, gidxB):
return CSRMask.apply(gidxA, A_weights, gidxB)
...@@ -3,8 +3,11 @@ import numpy as np ...@@ -3,8 +3,11 @@ import numpy as np
from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy
from ...base import is_all, ALL from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...heterograph_index import create_unitgraph_from_csr
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add'] __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']
def _scatter_nd(index, src, n_rows): def _scatter_nd(index, src, n_rows):
...@@ -295,3 +298,64 @@ def scatter_add(x, idx, m): ...@@ -295,3 +298,64 @@ def scatter_add(x, idx, m):
def _lambda(x): def _lambda(x):
return scatter_add_real(x, idx, m) return scatter_add_real(x, idx, m)
return _lambda(x) return _lambda(x)
def csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes):
gidxC, C_weights = _csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr')
def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful.
dgidxA, dA_weights = _csrmm(
gidxC, dC_weights, gidxB.reverse(), B_weights, gidxA.number_of_ntypes())
dgidxB, dB_weights = _csrmm(
gidxA.reverse(), A_weights, gidxC, dC_weights, gidxB.number_of_ntypes())
dA_weights = _csrmask(dgidxA, dA_weights, gidxA)
dB_weights = _csrmask(dgidxB, dB_weights, gidxB)
return dA_weights, dB_weights
return (tf.constant(nrows), tf.constant(ncols), C_indptr, C_indices, C_eids, C_weights), grad
def csrmm(gidxA, A_weights, gidxB, B_weights, num_vtypes):
@tf.custom_gradient
def _lambda(A_weights, B_weights):
return csrmm_real(gidxA, A_weights, gidxB, B_weights, num_vtypes)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = _lambda(A_weights, B_weights)
gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.numpy(), ncols.numpy(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
return gidxC, C_weights
def csrsum_real(gidxs, weights):
gidxC, C_weights = _csrsum(gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids = gidxC.adjacency_matrix_tensors(0, True, 'csr')
def grad(dnrows, dncols, dC_indptr, dC_indices, dC_eids, dC_weights):
# Only the last argument is meaningful.
return tuple(_csrmask(gidxC, dC_weights, gidx) for gidx in gidxs)
return (tf.constant(nrows), tf.constant(ncols), C_indptr, C_indices, C_eids, C_weights), grad
def csrsum(gidxs, weights):
@tf.custom_gradient
def _lambda(*weights):
return csrsum_real(gidxs, weights)
nrows, ncols, C_indptr, C_indices, C_eids, C_weights = _lambda(*weights)
num_vtypes = gidxs[0].number_of_ntypes()
gidxC = create_unitgraph_from_csr(
num_vtypes, nrows.numpy(), ncols.numpy(), C_indptr, C_indices, C_eids,
["coo", "csr", "csc"])
return gidxC, C_weights
def csrmask_real(gidxA, A_weights, gidxB):
B_weights = _csrmask(gidxA, A_weights, gidxB)
def grad(dB_weights):
return _csrmask(gidxB, dB_weights, gidxA)
return B_weights, grad
def csrmask(gidxA, A_weights, gidxB):
@tf.custom_gradient
def _lambda(A_weights):
return csrmask_real(gidxA, A_weights, gidxB)
return _lambda(A_weights)
...@@ -335,30 +335,16 @@ def heterograph(data_dict, ...@@ -335,30 +335,16 @@ def heterograph(data_dict,
' the max ID in the data, but got {} and {}.'.format( ' the max ID in the data, but got {} and {}.'.format(
dty, num_nodes_dict[dty], vrange - 1)) dty, num_nodes_dict[dty], vrange - 1))
# Create the graph # Create the graph
metagraph, ntypes, etypes, relations = heterograph_index.create_metagraph_index(
# Sort the ntypes and relation tuples to have a deterministic order for the same set num_nodes_dict.keys(), node_tensor_dict.keys())
# of type names.
ntypes = list(sorted(num_nodes_dict.keys()))
relations = list(sorted(node_tensor_dict.keys()))
num_nodes_per_type = utils.toindex([num_nodes_dict[ntype] for ntype in ntypes], "int64") num_nodes_per_type = utils.toindex([num_nodes_dict[ntype] for ntype in ntypes], "int64")
ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)}
meta_edges_src = []
meta_edges_dst = []
etypes = []
rel_graphs = [] rel_graphs = []
for srctype, etype, dsttype in relations: for srctype, etype, dsttype in relations:
meta_edges_src.append(ntype_dict[srctype])
meta_edges_dst.append(ntype_dict[dsttype])
etypes.append(etype)
src, dst = node_tensor_dict[(srctype, etype, dsttype)] src, dst = node_tensor_dict[(srctype, etype, dsttype)]
g = create_from_edges(src, dst, srctype, etype, dsttype, g = create_from_edges(src, dst, srctype, etype, dsttype,
num_nodes_dict[srctype], num_nodes_dict[dsttype]) num_nodes_dict[srctype], num_nodes_dict[dsttype])
rel_graphs.append(g) rel_graphs.append(g)
# metagraph is DGLGraph, currently still using int64 as index dtype
metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True)
# create graph index # create graph index
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type) metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type)
......
...@@ -8,6 +8,7 @@ import scipy ...@@ -8,6 +8,7 @@ import scipy
from ._ffi.object import register_object, ObjectBase from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError, dgl_warning from .base import DGLError, dgl_warning
from .graph_index import from_coo
from . import backend as F from . import backend as F
from . import utils from . import utils
...@@ -649,6 +650,60 @@ class HeteroGraphIndex(ObjectBase): ...@@ -649,6 +650,60 @@ class HeteroGraphIndex(ObjectBase):
else: else:
raise Exception("unknown format") raise Exception("unknown format")
def adjacency_matrix_tensors(self, etype, transpose, fmt):
"""Return the adjacency matrix as a triplet of tensors.
By default, a row of returned adjacency matrix represents the destination
of an edge and the column represents the source.
When transpose is True, a row represents the source and a column represents
a destination.
Parameters
----------
etype : int
Edge type
transpose : bool
A flag to transpose the returned adjacency matrix.
fmt : str
Indicates the format of returned adjacency matrix.
Returns
-------
tuple[int, int, Tensor, Tensor] or tuple[int, int, Tensor, Tensor, Tensor]
The number of rows and columns, followed by the adjacency matrix tensors
whose data type and device are the same as those of the graph.
If :attr:`fmt` is ``'coo'``, then the triplet will be
the row array and column array of the COO representation.
If :attr:`fmt` is ``'csr'``, then the triplet will be
the index pointer array (``indptr``), indices array, and data array
of the CSR representation. The data array will contain the edge ID for
each entry of the adjacency matrix. If the data array is empty, then it is
equivalent to a consecutive array from zero to the number of edges minus one.
"""
if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose)))
rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt)
srctype, dsttype = self.metagraph.find_edge(etype)
nrows = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype)
ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype)
nnz = self.number_of_edges(etype)
if fmt == "csr":
indptr = F.from_dgl_nd(rst(0))
indices = F.from_dgl_nd(rst(1))
data = F.from_dgl_nd(rst(2))
return nrows, ncols, indptr, indices, data
elif fmt == 'coo':
idx = F.from_dgl_nd(rst(0))
row, col = F.reshape(idx, (2, nnz))
return nrows, ncols, row, col
else:
raise ValueError("unknown format")
def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None): def adjacency_matrix_scipy(self, etype, transpose, fmt, return_edge_ids=None):
"""Return the scipy adjacency matrix representation of this graph. """Return the scipy adjacency matrix representation of this graph.
...@@ -674,10 +729,6 @@ class HeteroGraphIndex(ObjectBase): ...@@ -674,10 +729,6 @@ class HeteroGraphIndex(ObjectBase):
scipy.sparse.spmatrix scipy.sparse.spmatrix
The scipy representation of adjacency matrix. The scipy representation of adjacency matrix.
""" """
if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose)))
if return_edge_ids is None: if return_edge_ids is None:
dgl_warning( dgl_warning(
"Adjacency matrix by default currently returns edge IDs." "Adjacency matrix by default currently returns edge IDs."
...@@ -687,26 +738,30 @@ class HeteroGraphIndex(ObjectBase): ...@@ -687,26 +738,30 @@ class HeteroGraphIndex(ObjectBase):
FutureWarning) FutureWarning)
return_edge_ids = True return_edge_ids = True
rst = _CAPI_DGLHeteroGetAdj(self, int(etype), transpose, fmt) if fmt == 'csr':
srctype, dsttype = self.metagraph.find_edge(etype) nrows, ncols, indptr, indices, data = \
nrows = self.number_of_nodes(srctype) if transpose else self.number_of_nodes(dsttype) self.adjacency_matrix_tensors(etype, transpose, fmt)
ncols = self.number_of_nodes(dsttype) if transpose else self.number_of_nodes(srctype) indptr = F.asnumpy(indptr)
nnz = self.number_of_edges(etype) indices = F.asnumpy(indices)
if fmt == "csr": data = F.asnumpy(data)
indptr = utils.toindex(rst(0), self.dtype).tonumpy()
indices = utils.toindex(rst(1), self.dtype).tonumpy()
data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices)
# Check if edge ID is omitted # Check if edge ID is omitted
if return_edge_ids and data.shape[0] == 0: if return_edge_ids and data.shape[0] == 0:
data = np.arange(nnz) data = np.arange(self.number_of_edges(etype))
else:
data = np.ones_like(indices)
return scipy.sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols)) return scipy.sparse.csr_matrix((data, indices, indptr), shape=(nrows, ncols))
elif fmt == 'coo': elif fmt == 'coo':
idx = utils.toindex(rst(0), self.dtype).tonumpy() nrows, ncols, row, col = \
row, col = np.reshape(idx, (2, nnz)) self.adjacency_matrix_tensors(etype, transpose, fmt)
data = np.arange(0, nnz) if return_edge_ids else np.ones_like(row) row = F.asnumpy(row)
col = F.asnumpy(col)
data = np.arange(self.number_of_edges(etype)) if return_edge_ids \
else np.ones_like(row)
return scipy.sparse.coo_matrix((data, (row, col)), shape=(nrows, ncols)) return scipy.sparse.coo_matrix((data, (row, col)), shape=(nrows, ncols))
else: else:
raise Exception("unknown format") raise ValueError("unknown format")
def incidence_matrix(self, etype, typestr, ctx): def incidence_matrix(self, etype, typestr, ctx):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
...@@ -972,6 +1027,46 @@ class HeteroSubgraphIndex(ObjectBase): ...@@ -972,6 +1027,46 @@ class HeteroSubgraphIndex(ObjectBase):
# Creators # Creators
################################################################# #################################################################
def create_metagraph_index(ntypes, canonical_etypes):
"""Return a GraphIndex instance for a metagraph given the node types and canonical
edge types.
This function will reorder the node types and canonical edge types.
Parameters
----------
ntypes : Iterable[str]
The node types.
canonical_etypes : Iterable[tuple[str, str, str]]
The canonical edge types.
Returns
-------
GraphIndex
The index object for metagraph.
list[str]
The reordered node types for each node in the metagraph.
list[str]
The reordered edge types for each edge in the metagraph.
list[tuple[str, str, str]]
The reordered canonical edge types for each edge in the metagraph.
"""
# Sort the ntypes and relation tuples to have a deterministic order for the same set
# of type names.
ntypes = list(sorted(ntypes))
relations = list(sorted(canonical_etypes))
ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)}
meta_edges_src = []
meta_edges_dst = []
etypes = []
for srctype, etype, dsttype in relations:
meta_edges_src.append(ntype_dict[srctype])
meta_edges_dst.append(ntype_dict[dsttype])
etypes.append(etype)
# metagraph is DGLGraph, currently still using int64 as index dtype
metagraph = from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True)
return metagraph, ntypes, etypes, relations
def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col, def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col,
formats, row_sorted=False, col_sorted=False): formats, row_sorted=False, col_sorted=False):
"""Create a unitgraph graph index from COO format """Create a unitgraph graph index from COO format
......
...@@ -366,7 +366,7 @@ def _bwd_segment_cmp(feat, arg, m): ...@@ -366,7 +366,7 @@ def _bwd_segment_cmp(feat, arg, m):
to_dgl_nd_for_write(out)) to_dgl_nd_for_write(out))
return out return out
def csrmm(A, A_weights, B, B_weights, num_vtypes): def _csrmm(A, A_weights, B, B_weights, num_vtypes):
"""Return a graph whose adjacency matrix is the sparse matrix multiplication """Return a graph whose adjacency matrix is the sparse matrix multiplication
of those of two given graphs. of those of two given graphs.
...@@ -397,7 +397,7 @@ def csrmm(A, A_weights, B, B_weights, num_vtypes): ...@@ -397,7 +397,7 @@ def csrmm(A, A_weights, B, B_weights, num_vtypes):
A, F.to_dgl_nd(A_weights), B, F.to_dgl_nd(B_weights), num_vtypes) A, F.to_dgl_nd(A_weights), B, F.to_dgl_nd(B_weights), num_vtypes)
return C, F.from_dgl_nd(C_weights) return C, F.from_dgl_nd(C_weights)
def csrsum(As, A_weights): def _csrsum(As, A_weights):
"""Return a graph whose adjacency matrix is the sparse matrix summation """Return a graph whose adjacency matrix is the sparse matrix summation
of the given list of graphs. of the given list of graphs.
...@@ -421,7 +421,7 @@ def csrsum(As, A_weights): ...@@ -421,7 +421,7 @@ def csrsum(As, A_weights):
C, C_weights = _CAPI_DGLCSRSum(As, [F.to_dgl_nd(w) for w in A_weights]) C, C_weights = _CAPI_DGLCSRSum(As, [F.to_dgl_nd(w) for w in A_weights])
return C, F.from_dgl_nd(C_weights) return C, F.from_dgl_nd(C_weights)
def csrmask(A, A_weights, B): def _csrmask(A, A_weights, B):
"""Return the weights of A at the locations identical to the sparsity pattern """Return the weights of A at the locations identical to the sparsity pattern
of B. of B.
......
...@@ -9,6 +9,7 @@ from ._ffi.function import _init_api ...@@ -9,6 +9,7 @@ from ._ffi.function import _init_api
from .base import dgl_warning, DGLError from .base import dgl_warning, DGLError
from . import convert from . import convert
from .heterograph import DGLHeteroGraph, DGLBlock from .heterograph import DGLHeteroGraph, DGLBlock
from .heterograph_index import create_metagraph_index, create_heterograph_from_relations
from .frame import Frame from .frame import Frame
from . import ndarray as nd from . import ndarray as nd
from . import backend as F from . import backend as F
...@@ -46,7 +47,9 @@ __all__ = [ ...@@ -46,7 +47,9 @@ __all__ = [
'metis_partition_assignment', 'metis_partition_assignment',
'partition_graph_with_halo', 'partition_graph_with_halo',
'metis_partition', 'metis_partition',
'as_heterograph'] 'as_heterograph',
'adj_product_graph',
'adj_sum_graph']
def pairwise_squared_distance(x): def pairwise_squared_distance(x):
...@@ -2223,6 +2226,242 @@ def to_simple(g, ...@@ -2223,6 +2226,242 @@ def to_simple(g,
DGLHeteroGraph.to_simple = utils.alias_func(to_simple) DGLHeteroGraph.to_simple = utils.alias_func(to_simple)
def _unitgraph_less_than_int32(g):
"""Check if a graph with only one edge type has more than 2 ** 31 - 1
nodes or edges.
"""
num_edges = g.num_edges()
num_nodes = max(g.num_nodes(g.ntypes[0]), g.num_nodes(g.ntypes[-1]))
return max(num_nodes, num_edges) <= (1 << 31) - 1
def adj_product_graph(A, B, weight_name, etype='_E'):
r"""Create a weighted graph whose adjacency matrix is the product of
the adjacency matrices of the given two graphs.
Namely, given two weighted graphs :attr:`A` and :attr:`B`, whose rows
represent source nodes and columns represent destination nodes, this function
returns a new graph whose weighted adjacency matrix is
:math:`\mathrm{adj}(A) \times \mathrm{adj}(B)`.
The two graphs must be simple graphs, and must have only one edge type.
Moreover, the number of nodes of the destination node type of :attr:`A` must
be the same as the number of nodes of the source node type of :attr:`B`.
The source node type of the returned graph will be the same as the source
node type of graph :attr:`A`. The destination node type of the returned
graph will be the same as the destination node type of graph :attr:`B`.
If the two node types are the same, the returned graph will be homogeneous.
Otherwise, it will be a bipartite graph.
Unlike ``scipy``, if an edge in the result graph has zero weight, it will
not be removed from the graph.
Notes
-----
This function works on both CPU and GPU. For GPU, the number of nodes and
edges must be less than the maximum of ``int32`` (i.e. ``2 ** 31 - 1``) due
to restriction of cuSPARSE.
The edge weights returned by this function is differentiable w.r.t. the
input edge weights.
If the graph format is restricted, both graphs must have CSR available.
Parameters
----------
A : DGLGraph
The graph as left operand.
B : DGLGraph
The graph as right operand.
weight_name : str
The feature name of edge weight of both graphs.
The corresponding edge feature must be scalar.
etype : str, optional
The edge type of the returned graph.
Returns
-------
DGLGraph
The new graph. The edge weight of the returned graph will have the
same feature name as :attr:`weight_name`.
Examples
--------
The following shows weighted adjacency matrix multiplication between two
bipartite graphs. You can also perform this between two homogeneous
graphs, or one homogeneous graph and one bipartite graph, as long as the
numbers of nodes of the same type match.
>>> A = dgl.heterograph({
... ('A', 'AB', 'B'): ([2, 2, 0, 2, 0, 1], [2, 1, 0, 0, 2, 2])},
... num_nodes_dict={'A': 3, 'B': 4})
>>> B = dgl.heterograph({
... ('B', 'BA', 'A'): ([0, 3, 2, 1, 3, 3], [1, 2, 0, 2, 1, 0])},
... num_nodes_dict={'A': 3, 'B': 4})
>>> A.edata['w'] = torch.randn(6).requires_grad_()
>>> B.edata['w'] = torch.randn(6).requires_grad_()
If your graph is a multigraph, you will need to call :func:`dgl.to_simple`
to convert it into a simple graph first.
>>> A = dgl.to_simple(A)
>>> B = dgl.to_simple(B)
>>> C = dgl.adj_product_graph(A, B, 'w')
>>> C.edges()
(tensor([0, 0, 1, 2, 2, 2]), tensor([0, 1, 0, 0, 2, 1]))
>>> C.edata['w']
tensor([0.6906, 0.2002, 0.0591, 0.3672, 0.1066, 0.1328],
grad_fn=<CSRMMBackward>)
Note that this function is differentiable:
>>> C.edata['w'].sum().backward()
>>> A.edata['w'].grad
tensor([0.7153, 0.2775, 0.7141, 0.7141, 0.7153, 0.7153])
>>> B.edata['w'].grad
tensor([0.4664, 0.0000, 1.5614, 0.3840, 0.0000, 0.0000])
If the source node type of the left operand is the same as the destination
node type of the right operand, this function returns a homogeneous graph:
>>> C.ntypes
['A']
Otherwise, it returns a bipartite graph instead:
>>> A = dgl.heterograph({
... ('A', 'AB', 'B'): ([2, 2, 0, 2, 0, 1], [2, 1, 0, 0, 2, 2])},
... num_nodes_dict={'A': 3, 'B': 4})
>>> B = dgl.heterograph({
... ('B', 'BC', 'C'): ([0, 3, 2, 1, 3, 3], [1, 2, 0, 2, 1, 0])},
... num_nodes_dict={'C': 3, 'B': 4})
>>> A.edata['w'] = torch.randn(6).requires_grad_()
>>> B.edata['w'] = torch.randn(6).requires_grad_()
>>> C = dgl.adj_product_graph(A, B, 'w')
>>> C.ntypes
['A', 'C']
"""
srctype, _, _ = A.canonical_etypes[0]
_, _, dsttype = B.canonical_etypes[0]
num_vtypes = 1 if srctype == dsttype else 2
ntypes = [srctype] if num_vtypes == 1 else [srctype, dsttype]
if A.device != F.cpu():
if not (_unitgraph_less_than_int32(A) and _unitgraph_less_than_int32(B)):
raise ValueError(
'For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1.')
C_gidx, C_weights = F.csrmm(
A._graph, A.edata[weight_name], B._graph, B.edata[weight_name], num_vtypes)
num_nodes_dict = {srctype: A.num_nodes(srctype), dsttype: B.num_nodes(dsttype)}
C_metagraph, ntypes, etypes, _ = \
create_metagraph_index(ntypes, [(srctype, etype, dsttype)])
num_nodes_per_type = [num_nodes_dict[ntype] for ntype in ntypes]
C_gidx = create_heterograph_from_relations(
C_metagraph, [C_gidx], utils.toindex(num_nodes_per_type))
C = DGLHeteroGraph(C_gidx, ntypes, etypes)
C.edata[weight_name] = C_weights
return C
def adj_sum_graph(graphs, weight_name):
r"""Create a weighted graph whose adjacency matrix is the sum of the
adjacency matrices of the given graphs, whose rows represent source nodes
and columns represent destination nodes.
All the graphs must be simple graphs, and must have only one edge type.
They also must have the same metagraph, i.e. have the same source node type
and the same destination node type. Moreover, the number of nodes for every
graph must also be the same.
The metagraph of the returned graph will be the same as the input graphs.
Unlike ``scipy``, if an edge in the result graph has zero weight, it will
not be removed from the graph.
Notes
-----
This function works on both CPU and GPU. For GPU, the number of nodes and
edges must be less than the maximum of ``int32`` (i.e. ``2 ** 31 - 1``) due
to restriction of cuSPARSE.
The edge weights returned by this function is differentiable w.r.t. the
input edge weights.
If the graph format is restricted, both graphs must have CSR available.
Parameters
----------
graphs : list[DGLGraph]
The list of graphs. Must have at least one element.
weight_name : str
The feature name of edge weight of both graphs.
The corresponding edge feature must be scalar.
Returns
-------
DGLGraph
The new graph. The edge weight of the returned graph will have the
same feature name as :attr:`weight_name`.
Examples
--------
The following shows weighted adjacency matrix summation between two
bipartite graphs. You can also perform this between homogeneous graphs.
>>> A = dgl.heterograph(
... {('A', 'AB', 'B'): ([2, 2, 0, 2, 0, 1], [2, 1, 0, 0, 2, 2])},
... num_nodes_dict={'A': 3, 'B': 4})
>>> B = dgl.heterograph(
... {('A', 'AB', 'B'): ([1, 2, 0, 2, 1, 0], [0, 3, 2, 1, 3, 3])},
... num_nodes_dict={'A': 3, 'B': 4})
>>> A.edata['w'] = torch.randn(6).requires_grad_()
>>> B.edata['w'] = torch.randn(6).requires_grad_()
If your graph is a multigraph, you will need to call :func:`dgl.to_simple`
to convert it into a simple graph first.
>>> A = dgl.to_simple(A)
>>> B = dgl.to_simple(B)
>>> C = dgl.adj_sum_graph([A, B], 'w')
>>> C.edges()
(tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 2]),
tensor([0, 2, 3, 2, 0, 3, 0, 1, 2, 3]))
Note that this function is differentiable:
>>> C.edata['w'].sum().backward()
>>> A.edata['w'].grad
tensor([1., 1., 1., 1., 1., 1.])
>>> B.edata['w'].grad
tensor([1., 1., 1., 1., 1., 1.])
"""
if len(graphs) == 0:
raise ValueError('The list of graphs must not be empty.')
if graphs[0].device != F.cpu():
if not all(_unitgraph_less_than_int32(A) for A in graphs):
raise ValueError(
'For GPU graphs the number of nodes and edges must be less than 2 ** 31 - 1.')
metagraph = graphs[0]._graph.metagraph
num_nodes = utils.toindex(
[graphs[0]._graph.number_of_nodes(i) for i in range(graphs[0]._graph.number_of_ntypes())])
weights = [A.edata[weight_name] for A in graphs]
gidxs = [A._graph for A in graphs]
C_gidx, C_weights = F.csrsum(gidxs, weights)
C_gidx = create_heterograph_from_relations(metagraph, [C_gidx], num_nodes)
C = DGLHeteroGraph(C_gidx, graphs[0].ntypes, graphs[0].etypes)
C.edata[weight_name] = C_weights
return C
def as_heterograph(g, ntype='_U', etype='_E'): # pylint: disable=unused-argument def as_heterograph(g, ntype='_U', etype='_E'): # pylint: disable=unused-argument
"""Convert a DGLGraph to a DGLHeteroGraph with one node and edge type. """Convert a DGLGraph to a DGLHeteroGraph with one node and edge type.
......
...@@ -104,12 +104,19 @@ bool CSRIsSorted(CSRMatrix csr); ...@@ -104,12 +104,19 @@ bool CSRIsSorted(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData( runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids,
runtime::NDArray weights, DType filler); runtime::NDArray weights, DType filler);
template <DLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler) {
return CSRGetData<XPU, IdType, DType>(csr, rows, cols, false, weights, filler);
}
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) { NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, NullArray(rows->dtype), -1); return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, true, NullArray(rows->dtype), -1);
} }
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
......
...@@ -39,7 +39,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data, ...@@ -39,7 +39,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData( NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) { CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0]; const int64_t collen = cols->shape[0];
...@@ -56,7 +56,6 @@ NDArray CSRGetData( ...@@ -56,7 +56,6 @@ NDArray CSRGetData(
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr; const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr;
const int64_t retlen = std::max(rowlen, collen); const int64_t retlen = std::max(rowlen, collen);
bool return_eids = IsNullArray(weights);
const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>(); const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();
if (return_eids) if (return_eids)
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) << BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
...@@ -105,19 +104,19 @@ NDArray CSRGetData( ...@@ -105,19 +104,19 @@ NDArray CSRGetData(
} }
template NDArray CSRGetData<kDLCPU, int32_t, float>( template NDArray CSRGetData<kDLCPU, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int64_t, float>( template NDArray CSRGetData<kDLCPU, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int32_t, double>( template NDArray CSRGetData<kDLCPU, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
template NDArray CSRGetData<kDLCPU, int64_t, double>( template NDArray CSRGetData<kDLCPU, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray) // For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>( template NDArray CSRGetData<kDLCPU, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int32_t filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLCPU, int64_t, int64_t>( template NDArray CSRGetData<kDLCPU, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int64_t filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -127,7 +127,9 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -127,7 +127,9 @@ std::pair<CSRMatrix, NDArray> CSRMM(
B_indptr, B_indices, B_eids, B_data, B_indptr, B_indices, B_eids, B_data,
C_indptr_data, C_indices_data, C_weights_data, M); C_indptr_data, C_indices_data, C_weights_data, M);
return {CSRMatrix(M, P, C_indptr, C_indices), C_weights}; return {
CSRMatrix(M, P, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
C_weights};
} }
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, float>( template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, float>(
......
...@@ -124,7 +124,9 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -124,7 +124,9 @@ std::pair<CSRMatrix, NDArray> CSRSum(
A_indptr, A_indices, A_eids, A_data, A_indptr, A_indices, A_eids, A_data,
C_indptr_data, C_indices_data, C_weights_data, M); C_indptr_data, C_indices_data, C_weights_data, M);
return {CSRMatrix(M, N, C_indptr, C_indices), C_weights}; return {
CSRMatrix(M, N, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
C_weights};
} }
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, float>( template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, float>(
......
...@@ -19,7 +19,7 @@ namespace impl { ...@@ -19,7 +19,7 @@ namespace impl {
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData( NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) { CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0]; const int64_t collen = cols->shape[0];
...@@ -37,7 +37,6 @@ NDArray CSRGetData( ...@@ -37,7 +37,6 @@ NDArray CSRGetData(
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const int nt = cuda::FindNumThreads(rstlen); const int nt = cuda::FindNumThreads(rstlen);
const int nb = (rstlen + nt - 1) / nt; const int nb = (rstlen + nt - 1) / nt;
bool return_eids = IsNullArray(weights);
if (return_eids) if (return_eids)
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) << BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
"DType does not match row's dtype."; "DType does not match row's dtype.";
...@@ -54,19 +53,19 @@ NDArray CSRGetData( ...@@ -54,19 +53,19 @@ NDArray CSRGetData(
} }
template NDArray CSRGetData<kDLGPU, int32_t, float>( template NDArray CSRGetData<kDLGPU, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int64_t, float>( template NDArray CSRGetData<kDLGPU, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int32_t, double>( template NDArray CSRGetData<kDLGPU, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
template NDArray CSRGetData<kDLGPU, int64_t, double>( template NDArray CSRGetData<kDLGPU, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray) // For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLGPU, int32_t, int32_t>( template NDArray CSRGetData<kDLGPU, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int32_t filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLGPU, int64_t, int64_t>( template NDArray CSRGetData<kDLGPU, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, int64_t filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -118,7 +118,10 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm( ...@@ -118,7 +118,10 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
CUSPARSE_CALL(cusparseDestroySpMat(matA)); CUSPARSE_CALL(cusparseDestroySpMat(matA));
CUSPARSE_CALL(cusparseDestroySpMat(matB)); CUSPARSE_CALL(cusparseDestroySpMat(matB));
CUSPARSE_CALL(cusparseDestroySpMat(matC)); CUSPARSE_CALL(cusparseDestroySpMat(matC));
return {CSRMatrix(A.num_rows, B.num_cols, dC_csrOffsets, dC_columns), dC_weights}; return {
CSRMatrix(A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,
NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)),
dC_weights};
} }
#else // __CUDACC_VER_MAJOR__ != 11 #else // __CUDACC_VER_MAJOR__ != 11
...@@ -197,7 +200,9 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm( ...@@ -197,7 +200,9 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
CUSPARSE_CALL(cusparseDestroyMatDescr(matC)); CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
CUSPARSE_CALL(cusparseDestroyMatDescr(matD)); CUSPARSE_CALL(cusparseDestroyMatDescr(matD));
return {CSRMatrix(m, k, C_indptr, C_indices), C_weights}; return {
CSRMatrix(m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
C_weights};
} }
#endif // __CUDACC_VER_MAJOR__ == 11 #endif // __CUDACC_VER_MAJOR__ == 11
...@@ -240,7 +245,8 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -240,7 +245,8 @@ std::pair<CSRMatrix, NDArray> CSRMM(
if (cast) { if (cast) {
CSRMatrix C = result.first; CSRMatrix C = result.first;
return { return {
CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64)), CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64),
AsNumBits(C.data, 64)),
result.second}; result.second};
} else { } else {
return result; return result;
......
...@@ -48,7 +48,7 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2( ...@@ -48,7 +48,7 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
cusparseSetPointerMode(thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST); cusparseSetPointerMode(thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST);
size_t workspace_size = 0; size_t workspace_size = 0;
/* prepare output C */ /* prepare output C */
IdArray dC_csrOffsets = IdArray::Empty({A.num_rows+1}, A.indptr->dtype, ctx); IdArray dC_csrOffsets = IdArray::Empty({m + 1}, A.indptr->dtype, ctx);
IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>(); IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
IdArray dC_columns; IdArray dC_columns;
NDArray dC_weights; NDArray dC_weights;
...@@ -97,7 +97,9 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2( ...@@ -97,7 +97,9 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
CUSPARSE_CALL(cusparseDestroyMatDescr(matA)); CUSPARSE_CALL(cusparseDestroyMatDescr(matA));
CUSPARSE_CALL(cusparseDestroyMatDescr(matB)); CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
CUSPARSE_CALL(cusparseDestroyMatDescr(matC)); CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
return {CSRMatrix(A.num_rows, A.num_cols, dC_csrOffsets, dC_columns), return {
CSRMatrix(A.num_rows, A.num_cols, dC_csrOffsets, dC_columns,
NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx), true),
dC_weights}; dC_weights};
} }
} // namespace cusparse } // namespace cusparse
...@@ -112,22 +114,31 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -112,22 +114,31 @@ std::pair<CSRMatrix, NDArray> CSRSum(
// Cast 64 bit indices to 32 bit // Cast 64 bit indices to 32 bit
std::vector<CSRMatrix> newAs; std::vector<CSRMatrix> newAs;
newAs.reserve(n);
bool cast = false; bool cast = false;
if (As[0].indptr->dtype.bits == 64) { if (As[0].indptr->dtype.bits == 64) {
newAs.reserve(n);
for (int i = 0; i < n; ++i) for (int i = 0; i < n; ++i)
newAs.emplace_back( newAs.emplace_back(
As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32), As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32),
AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32)); AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32));
cast = true; cast = true;
} else {
for (int i = 0; i < n; ++i)
newAs.push_back(As[i]);
}
// cuSPARSE csrgeam2 requires the CSR to be sorted.
// TODO(BarclayII): ideally the sorted CSR should be cached but I'm not sure how to do it.
for (int i = 0; i < n; ++i) {
if (!newAs[i].sorted)
newAs[i] = CSRSort(newAs[i]);
} }
const std::vector<CSRMatrix> &As_ref = cast ? newAs : As;
// Reorder weights if A[i] has edge IDs // Reorder weights if A[i] has edge IDs
std::vector<NDArray> A_weights_reordered(n); std::vector<NDArray> A_weights_reordered(n);
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
if (CSRHasData(As[i])) if (CSRHasData(newAs[i]))
A_weights_reordered[i] = IndexSelect(A_weights[i], As[i].data); A_weights_reordered[i] = IndexSelect(A_weights[i], newAs[i].data);
else else
A_weights_reordered[i] = A_weights[i]; A_weights_reordered[i] = A_weights[i];
} }
...@@ -135,18 +146,20 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -135,18 +146,20 @@ std::pair<CSRMatrix, NDArray> CSRSum(
// Loop and sum // Loop and sum
auto result = std::make_pair( auto result = std::make_pair(
CSRMatrix( CSRMatrix(
As_ref[0].num_rows, As_ref[0].num_cols, newAs[0].num_rows, newAs[0].num_cols,
As_ref[0].indptr, As_ref[0].indices), newAs[0].indptr, newAs[0].indices,
NullArray(newAs[0].indptr->dtype, newAs[0].indptr->ctx)),
A_weights_reordered[0]); // Weights already reordered so we don't need As[0].data A_weights_reordered[0]); // Weights already reordered so we don't need As[0].data
for (int64_t i = 1; i < n; ++i) for (int64_t i = 1; i < n; ++i)
result = cusparse::CusparseCsrgeam2<DType, int32_t>( result = cusparse::CusparseCsrgeam2<DType, int32_t>(
result.first, result.second, As_ref[i], A_weights_reordered[i]); result.first, result.second, newAs[i], A_weights_reordered[i]);
// Cast 32 bit indices back to 64 bit if necessary // Cast 32 bit indices back to 64 bit if necessary
if (cast) { if (cast) {
CSRMatrix C = result.first; CSRMatrix C = result.first;
return { return {
CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64)), CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64),
AsNumBits(C.data, 64), true),
result.second}; result.second};
} else { } else {
return result; return result;
......
...@@ -453,9 +453,9 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -453,9 +453,9 @@ class UnitGraph::CSR : public BaseHeteroGraph {
: BaseHeteroGraph(metagraph) { : BaseHeteroGraph(metagraph) {
CHECK(aten::IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices)); CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids)); if (aten::IsValidIdArray(edge_ids))
CHECK_EQ(indices->shape[0], edge_ids->shape[0]) CHECK((indices->shape[0] == edge_ids->shape[0]) || aten::IsNullArray(edge_ids))
<< "indices and edge id arrays should have the same length"; << "edge id arrays should have the same length as indices if not empty";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
} }
......
import numpy as np import numpy as np
import scipy.sparse as ssp import scipy.sparse as ssp
import pytest
import dgl import dgl
from utils import parametrize_dtype from utils import parametrize_dtype
import backend as F import backend as F
...@@ -11,51 +12,199 @@ def _random_simple_graph(idtype, dtype, ctx, M, N, max_nnz, srctype, dsttype, et ...@@ -11,51 +12,199 @@ def _random_simple_graph(idtype, dtype, ctx, M, N, max_nnz, srctype, dsttype, et
a = ssp.csr_matrix((val, (src, dst)), shape=(M, N)) a = ssp.csr_matrix((val, (src, dst)), shape=(M, N))
a.sum_duplicates() a.sum_duplicates()
a = a.tocoo() a = a.tocoo()
# shuffle edges
perm = np.random.permutation(a.nnz)
row = a.row[perm]
col = a.col[perm]
val = a.data[perm]
a = ssp.csr_matrix((val, (row, col)), shape=(M, N))
A = dgl.heterograph( A = dgl.heterograph(
{('A', 'AB', 'B'): ( {(srctype, etype, dsttype): (
F.copy_to(F.tensor(a.row, dtype=idtype), ctx), F.copy_to(F.tensor(row, dtype=idtype), ctx),
F.copy_to(F.tensor(a.col, dtype=idtype), ctx))}, F.copy_to(F.tensor(col, dtype=idtype), ctx))},
num_nodes_dict={'A': a.shape[0], 'B': a.shape[1]}) num_nodes_dict={srctype: a.shape[0], dsttype: a.shape[1]})
A.edata['w'] = F.copy_to(F.tensor(a.data, dtype=dtype), ctx) A.edata['w'] = F.copy_to(F.tensor(val, dtype=dtype), ctx)
return a, A return a, A
@parametrize_dtype @parametrize_dtype
def test_csrmm(idtype): @pytest.mark.parametrize('dtype', [F.float32, F.float64])
for dtype in [F.float32, F.float64]: def test_csrmm(idtype, dtype):
a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB')
b, B = _random_simple_graph(idtype, dtype, F.ctx(), 600, 700, 9000, 'B', 'C', 'BC') b, B = _random_simple_graph(idtype, dtype, F.ctx(), 600, 700, 9000, 'B', 'C', 'BC')
C, C_weights = dgl.sparse.csrmm(A._graph, A.edata['w'], B._graph, B.edata['w'], 2) C, C_weights = dgl.sparse._csrmm(A._graph, A.edata['w'], B._graph, B.edata['w'], 2)
C_adj = C.adjacency_matrix_scipy(0, True, 'csr') C_adj = C.adjacency_matrix_scipy(0, True, 'csr')
C_adj.data = F.asnumpy(C_weights) C_adj.data = F.asnumpy(C_weights)
C_adj = F.tensor(C_adj.todense(), dtype=dtype) C_adj = F.tensor(C_adj.todense(), dtype=dtype)
c = F.tensor((a * b).todense(), dtype=dtype) c = F.tensor((a * b).todense(), dtype=dtype)
assert F.allclose(C_adj, c) assert F.allclose(C_adj, c)
@parametrize_dtype
@pytest.mark.parametrize('dtype', [F.float32, F.float64])
@pytest.mark.parametrize('num_vtypes', [1, 2])
def test_csrmm_backward(idtype, dtype, num_vtypes):
a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
b, B = _random_simple_graph(idtype, dtype, F.ctx(), 4, 3, 6, 'B', 'A' if num_vtypes == 1 else 'C', 'BA')
A_row, A_col = A.edges(order='eid')
B_row, B_col = B.edges(order='eid')
A_row = F.asnumpy(A_row)
A_col = F.asnumpy(A_col)
B_row = F.asnumpy(B_row)
B_col = F.asnumpy(B_col)
a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))
b_dense = F.attach_grad(F.tensor(b.todense(), dtype=dtype))
A.edata['w'] = F.attach_grad(A.edata['w'])
B.edata['w'] = F.attach_grad(B.edata['w'])
with F.record_grad():
C = dgl.adj_product_graph(A, B, 'w')
assert len(C.ntypes) == num_vtypes
assert len(C.etypes) == 1
C_dense = np.zeros((3, 3))
C_row, C_col = C.edges(order='eid')
C_row = F.asnumpy(C_row)
C_col = F.asnumpy(C_col)
C_dense[C_row, C_col] = F.asnumpy(C.edata['w'])
c_dense = F.matmul(a_dense, b_dense)
assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4)
F.backward(F.reduce_sum(C.edata['w']) + F.reduce_sum(c_dense))
a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]
b_dense_grad = F.asnumpy(F.grad(b_dense))[B_row, B_col]
A_spspmm_grad = F.asnumpy(F.grad(A.edata['w']))
B_spspmm_grad = F.asnumpy(F.grad(B.edata['w']))
assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)
assert np.allclose(b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4)
@parametrize_dtype
@pytest.mark.parametrize('dtype', [F.float32, F.float64])
def test_csrsum(idtype, dtype):
a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB')
b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB')
C, C_weights = dgl.sparse._csrsum([A._graph, B._graph], [A.edata['w'], B.edata['w']])
C_adj = C.adjacency_matrix_scipy(0, True, 'csr')
C_adj.data = F.asnumpy(C_weights)
C_adj = F.tensor(C_adj.todense(), dtype=dtype)
c = F.tensor((a + b).todense(), dtype=dtype)
assert F.allclose(C_adj, c)
@parametrize_dtype
@pytest.mark.parametrize('dtype', [F.float32, F.float64])
@pytest.mark.parametrize('nelems', [1, 2])
def test_csrsum_backward(idtype, dtype, nelems):
a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
A_row, A_col = A.edges(order='eid')
B_row, B_col = B.edges(order='eid')
A_row = F.asnumpy(A_row)
A_col = F.asnumpy(A_col)
B_row = F.asnumpy(B_row)
B_col = F.asnumpy(B_col)
a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))
b_dense = F.attach_grad(F.tensor(b.todense(), dtype=dtype))
A.edata['w'] = F.attach_grad(A.edata['w'])
B.edata['w'] = F.attach_grad(B.edata['w'])
with F.record_grad():
if nelems == 2:
# Test for two element case
C = dgl.adj_sum_graph([A, B], 'w')
assert C.canonical_etypes == A.canonical_etypes
C_dense = np.zeros((3, 4))
C_row, C_col = C.edges(order='eid')
C_row = F.asnumpy(C_row)
C_col = F.asnumpy(C_col)
C_dense[C_row, C_col] = F.asnumpy(C.edata['w'])
c_dense = a_dense + b_dense
assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4)
F.backward(F.reduce_sum(C.edata['w']) + F.reduce_sum(c_dense))
a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]
b_dense_grad = F.asnumpy(F.grad(b_dense))[B_row, B_col]
A_spspmm_grad = F.asnumpy(F.grad(A.edata['w']))
B_spspmm_grad = F.asnumpy(F.grad(B.edata['w']))
assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)
assert np.allclose(b_dense_grad, B_spspmm_grad, rtol=1e-4, atol=1e-4)
elif nelems == 1:
# Test for single element case
C = dgl.adj_sum_graph([A], 'w')
assert C.canonical_etypes == A.canonical_etypes
C_dense = np.zeros((3, 4))
C_row, C_col = C.edges(order='eid')
C_row = F.asnumpy(C_row)
C_col = F.asnumpy(C_col)
C_dense[C_row, C_col] = F.asnumpy(C.edata['w'])
c_dense = a_dense
assert np.allclose(C_dense, F.asnumpy(c_dense), rtol=1e-4, atol=1e-4)
F.backward(F.reduce_sum(C.edata['w']) + F.reduce_sum(c_dense))
a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]
A_spspmm_grad = F.asnumpy(F.grad(A.edata['w']))
assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)
@parametrize_dtype @parametrize_dtype
def test_csrsum(idtype): @pytest.mark.parametrize('dtype', [F.float32, F.float64])
for dtype in [F.float32, F.float64]: @pytest.mark.parametrize('A_nnz', [9000, 0])
a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') @pytest.mark.parametrize('B_nnz', [9000, 0])
b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') def test_csrmask(idtype, dtype, A_nnz, B_nnz):
C, C_weights = dgl.sparse.csrsum([A._graph, B._graph], [A.edata['w'], B.edata['w']]) a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, A_nnz, 'A', 'B', 'AB')
C_adj = C.adjacency_matrix_scipy(0, True, 'csr') b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, B_nnz, 'A', 'B', 'AB')
C_adj.data = F.asnumpy(C_weights) C = dgl.sparse._csrmask(A._graph, A.edata['w'], B._graph)
C_adj = F.tensor(C_adj.todense(), dtype=dtype) B_row, B_col = B.edges(order='eid')
c = F.tensor((a + b).todense(), dtype=dtype) B_row = F.asnumpy(B_row)
assert F.allclose(C_adj, c) B_col = F.asnumpy(B_col)
c = F.tensor(a.todense()[B_row, B_col], dtype)
assert F.allclose(C, c)
@parametrize_dtype @parametrize_dtype
def test_csrmask(idtype): @pytest.mark.parametrize('dtype', [F.float32, F.float64])
for dtype in [F.float32, F.float64]: def test_csrmask_backward(idtype, dtype):
a, A = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') a, A = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
b, B = _random_simple_graph(idtype, dtype, F.ctx(), 500, 600, 9000, 'A', 'B', 'AB') b, B = _random_simple_graph(idtype, dtype, F.ctx(), 3, 4, 6, 'A', 'B', 'AB')
C = dgl.sparse.csrmask(A._graph, A.edata['w'], B._graph) A_row, A_col = A.edges(order='eid')
c = F.tensor(a.tocsr()[b != 0], dtype) B_row, B_col = B.edges(order='eid')
assert F.allclose(C, c) A_row = F.asnumpy(A_row)
A_col = F.asnumpy(A_col)
B_row = F.asnumpy(B_row)
B_col = F.asnumpy(B_col)
a_dense = F.attach_grad(F.tensor(a.todense(), dtype=dtype))
A.edata['w'] = F.attach_grad(A.edata['w'])
with F.record_grad():
# Test for two element case
C1 = F.csrmask(A._graph, A.edata['w'], B._graph)
if dgl.backend.backend_name == 'tensorflow':
import tensorflow as tf
C2 = tf.gather_nd(a_dense, tf.stack([B_row, B_col], 1))
else:
C2 = a_dense[B_row, B_col]
assert F.allclose(C1, C2, rtol=1e-4, atol=1e-4)
F.backward(F.reduce_sum(C1) + F.reduce_sum(C2))
a_dense_grad = F.asnumpy(F.grad(a_dense))[A_row, A_col]
A_spspmm_grad = F.asnumpy(F.grad(A.edata['w']))
assert np.allclose(a_dense_grad, A_spspmm_grad, rtol=1e-4, atol=1e-4)
if __name__ == '__main__': if __name__ == '__main__':
test_csrmm(F.int32) test_csrmm(F.int32, F.float32)
test_csrmm(F.int64) test_csrmm(F.int64, F.float32)
test_csrsum(F.int32) test_csrsum(F.int32, F.float32)
test_csrsum(F.int64) test_csrsum(F.int64, F.float32)
test_csrmask(F.int32) test_csrmask(F.int32, F.float32, 9000, 9000)
test_csrmask(F.int64) test_csrmask(F.int64, F.float32, 9000, 0)
test_csrmask(F.int32, F.float32, 0, 9000)
test_csrmask(F.int64, F.float32, 0, 0)
test_csrmm_backward(F.int32, F.float32, 1)
test_csrmm_backward(F.int64, F.float32, 1)
test_csrmm_backward(F.int32, F.float32, 2)
test_csrmm_backward(F.int64, F.float32, 2)
test_csrsum_backward(F.int32, F.float32, 1)
test_csrsum_backward(F.int64, F.float32, 1)
test_csrsum_backward(F.int32, F.float32, 2)
test_csrsum_backward(F.int64, F.float32, 2)
test_csrmask_backward(F.int32, F.float32)
test_csrmask_backward(F.int64, F.float32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment