"tests/python/vscode:/vscode.git/clone" did not exist on "c51516a83ec35e3377b27885328d498aa950c032"
Unverified commit 188152b8, authored by Israt Nisa and committed by GitHub

[Feature] Add Heterograph support on Python for builtin unary msg functions (copy_u, copy_e) (#2989)

* heterograph for binary func

* Added SDDMM support

* Added unittest

* added binary test cases

* unary mfuncs works

* Fixed lint err

* lint check and others

* link check

* fixed import *_hetero issue

* lint check

* replace torch with dgl backend

* lint check

* removed torch from test

* skip mxnet unittest

* skip gpu test

* Remove unused/duplicated code

* minor

* changed data structure of ndata and edata

* link check

* reorganized

* minor lint

* minor lint

* raise error for udf func

* lint check

* fix for CUDA 10.1

* add a note for future cross-type max/min reducing

* Add support CUDA < 11

* lint check

* tidied C code

* remove dummy GSDDMM_hetero backward implementation
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: Quan Gan <coin2028@hotmail.com>
parent 4556fab8
...@@ -1452,6 +1452,45 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
"""
pass
def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
r""" Generalized Sparse Matrix Multiplication interface on a heterogeneous graph.
All the relation types of the heterogeneous graph are processed together.
It fuses two steps into one kernel:
(1) Computes messages by :attr:`op` between source node and edge features.
(2) Aggregates the messages by :attr:`reduce_op` as the features on destination nodes.
.. math::
x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,
:math:`x_e` refer to :attr:`u` and :attr:`e` respectively. :math:`\rho` is the binary
operator :attr:`op` and :math:`\psi` is the reduce operator :attr:`reduce_op`;
:math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.
Note that this function does not handle gradients.
Parameters
----------
g : HeteroGraph
The input graph.
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``,
``copy_lhs``, ``copy_rhs``.
reduce_op : str
Reduce operator, could be ``sum``, ``max``, ``min``.
lhs_and_rhs_tuple : tuple of tensors
lhs_data and rhs_data concatenated into one tuple. lhs_data is itself a
tuple of tensors whose length is the number of node types; the same holds
for rhs_data. The tensors in the tuple may be None.
Returns
-------
tuple of tensors
The resulting tuple of tensors.
"""
pass
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
It computes edge features by :attr:`op` lhs features and rhs features.
...@@ -1487,6 +1526,44 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
"""
pass
def gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface on a
heterogeneous graph. All the relation types of the heterogeneous graph
are processed together.
It computes edge features by :attr:`op` on lhs features and rhs features.
.. math::
x_{e} = \phi(x_{lhs}, x_{rhs}), \forall (u,e,v)\in \mathcal{G}
where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,
:math:`x_v` refer to :attr:`u` and :attr:`v` respectively. :math:`\phi`
is the binary operator :attr:`op`, and :math:`\mathcal{G}` is the graph
we apply gsddmm on: :attr:`g`. :attr:`lhs` and :attr:`rhs` are each one of
``u``, ``v`` and ``e``.
Parameters
----------
g : HeteroGraph
The input graph.
op : str
Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
``copy_lhs``, ``copy_rhs``.
lhs_target : str
Choice of ``u`` (source), ``e`` (edge) or ``v`` (destination) for the left operand.
rhs_target : str
Choice of ``u`` (source), ``e`` (edge) or ``v`` (destination) for the right operand.
lhs_and_rhs_tuple : tuple of tensors
lhs_data and rhs_data concatenated into one tuple. lhs_data is itself a
tuple of tensors whose length is the number of node types; the same holds
for rhs_data. The tensors in the tuple may be None.
Returns
-------
tuple of tensors
The resulting tuple of tensors.
"""
pass
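The packed-tuple convention above is shared by ``gspmm_hetero`` and ``gsddmm_hetero``. As a minimal sketch (not DGL code; the node type names, ids and feature shapes below are made up), the argument tuple could be assembled like this:

```python
import torch

ntypes = ['user', 'game']                  # hypothetical node types, ordered by ntype id
lhs_by_type = {'user': torch.randn(3, 4),  # features on 'user' nodes
               'game': torch.randn(2, 4)}  # features on 'game' nodes

# One slot per node type for lhs_data; likewise for rhs_data
# (None where the op does not need that side, e.g. copy_lhs).
lhs_data = tuple(lhs_by_type.get(nt) for nt in ntypes)
rhs_data = (None,) * len(ntypes)
lhs_and_rhs_tuple = lhs_data + rhs_data    # what *lhs_and_rhs_tuple unpacks into

# A backend implementation would then be invoked roughly as:
# gsddmm_hetero(g, 'copy_lhs', 'u', 'v', *lhs_and_rhs_tuple)
```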
def edge_softmax(gidx, logits, eids, norm_by):
r"""Compute edge softmax.
......
import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...heterograph_index import create_unitgraph_from_csr
...@@ -26,7 +26,7 @@ else:
return bwd(*args, **kwargs)
return decorate_bwd
__all__ = ['gspmm', 'gsddmm', 'gspmm_hetero', 'gsddmm_hetero', 'edge_softmax', 'segment_reduce', 'scatter_add',
'csrmm', 'csrsum', 'csrmask']
...@@ -145,6 +145,49 @@ class GSpMM(th.autograd.Function):
return None, None, None, dX, dY
class GSpMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, g, op, reduce_op, *feats): # feats = lhs_data + rhs_data
out, (argX, argY) = _gspmm_hetero(g, op, reduce_op, feats)
ctx.backward_cache = g, op, reduce_op
ctx.save_for_backward(*feats, argX, argY)
return out
@staticmethod
@custom_bwd
def backward(ctx, *dZ):
g, op, reduce_op = ctx.backward_cache
feats = ctx.saved_tensors[:-2]
argX = ctx.saved_tensors[-2]
argY = ctx.saved_tensors[-1]
num_ntypes = g._graph.number_of_ntypes()
X, Y = feats[:num_ntypes], feats[num_ntypes:]
if op != 'copy_rhs' and any([x is not None for x in X]):
g_rev = g.reverse()
# TODO(Israt): implement other combinations of message and reduce functions
if reduce_op == 'sum':
if op == 'copy_lhs':
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', *dZ)
dX = tuple([_reduce_grad(dX[i], X[i].shape) if X[i] is not None else None
for i in range(len(X))])
else: # X has no gradient
dX = tuple([None] * len(X))
if op != 'copy_lhs' and any([y is not None for y in Y]):
# TODO(Israt): implement other combinations of message and reduce functions
if reduce_op == 'sum':
if op in ['copy_rhs']:
tmp_Z = tuple([_addsub(op, dZ[i]) if dZ[i] is not None else None
for i in range(len(dZ))])
tmp = tuple(X + tmp_Z)
dY = gsddmm_hetero(g, 'copy_rhs', 'u', 'v', *tmp)
dY = tuple([_reduce_grad(dY[i], Y[i].shape) if Y[i] is not None else None
for i in range(len(Y))])
else: # Y has no gradient
dY = tuple([None] * len(Y))
return (None, None, None) + dX + dY
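As a rough intuition for the ``copy_lhs``/``sum`` backward rule used above (a NumPy toy for a single relation type, not DGL code): the forward pass sum-aggregates copied source features, so the gradient is the same aggregation carried out on the reversed graph.

```python
import numpy as np

A = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 1.0]])   # dst x src adjacency of one relation type
X = np.random.randn(3, 4)         # source node features
dZ = np.random.randn(2, 4)        # gradient w.r.t. the output Z = A @ X

# Forward: Z = A @ X (copy_lhs + sum).  Backward: dL/dX = A.T @ dL/dZ,
# i.e. what gspmm_hetero(g_rev, 'copy_lhs', 'sum', *dZ) computes per node type.
dX = A.T @ dZ
```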
class GSDDMM(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
...@@ -206,6 +249,22 @@ class GSDDMM(th.autograd.Function):
return None, None, dX, dY, None, None
class GSDDMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, g, op, lhs_target, rhs_target, *feats): # feats = X+Y
out = _gsddmm_hetero(g, op, lhs_target, rhs_target, feats)
ctx.backward_cache = g, op, lhs_target, rhs_target
ctx.save_for_backward(*feats)
return out
@staticmethod
@custom_bwd
# TODO(Israt): Implement the backward operator
def backward(ctx, *dZ):
raise NotImplementedError('Homogenized GSDDMM backward operation is not implemented.')
class EdgeSoftmax(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
...@@ -365,15 +424,18 @@ class CSRMask(th.autograd.Function):
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
return GSpMM_hetero.apply(g, op, reduce_op, *lhs_and_rhs_tuple)
def gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
return GSDDMM_hetero.apply(g, op, lhs_target, rhs_target, *lhs_and_rhs_tuple)
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
def segment_reduce(op, x, offsets):
return SegmentReduce.apply(op, x, offsets)
......
...@@ -184,6 +184,40 @@ def _bucketing(val):
return bkts
return unique_val, bucketor
def data_dict_to_tuple(graph, data_dict, op, lhs_list=None, rhs_list=None):
"""Get node or edge feature data of the given name for all the types.
Parameters
-------------
graph : DGLGraph
The input graph.
data_dict : dict[str, Tensor] or dict[(str, str, str), Tensor]]
Node or edge data stored in DGLGraph. The key of the dictionary
is the node type name or edge type name.
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
``copy_lhs``, ``copy_rhs``.
lhs_list : list[tensor] or list[None]
The feature on source nodes, could be list of None if op is ``copy_rhs``.
rhs_list : list[tensor] or list[None]
The feature on edges, could be list of None if op is ``copy_lhs``.
Returns
--------
data_tuple : tuple(Tensor)
Feature data stored in tuple of tensors. The i^th tensor stores the feature
data of type ``types[i]``.
"""
if op == "copy_u":
for srctype, _, _ in graph.canonical_etypes:
src_id = graph.get_ntype_id(srctype)
lhs_list[src_id] = data_dict[srctype]
elif op == "copy_e":
for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel)
rhs_list[etid] = data_dict[rel]
return tuple(lhs_list + rhs_list)
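A self-contained illustration (plain Python/PyTorch, not DGL code; the type names, ids and shapes are invented) of the packing that ``data_dict_to_tuple`` performs for ``copy_u``: features keyed by node type name land in a list indexed by node type id, which is then concatenated with the (empty) edge-side list.

```python
import torch

ntype_ids = {'user': 0, 'game': 1}        # hypothetical ntype name -> id mapping
num_ntypes, num_etypes = 2, 3

data_dict = {'user': torch.randn(3, 2),   # srcdata['h'] per source node type
             'game': torch.randn(2, 2)}

lhs_list = [None] * num_ntypes
rhs_list = [None] * num_etypes
for srctype, feat in data_dict.items():   # mirrors the ``copy_u`` branch above
    lhs_list[ntype_ids[srctype]] = feat

data_tuple = tuple(lhs_list + rhs_list)   # what invoke_gspmm passes on to ops
```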
def invoke_gsddmm(graph, func):
"""Invoke g-SDDMM computation on the graph.
...@@ -255,6 +289,11 @@ def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None)
else:
x = alldata[mfunc.target][mfunc.in_field]
op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
if graph._graph.number_of_etypes() > 1:
# Convert to list as dict is unordered.
lhs_list = [None] * graph._graph.number_of_ntypes()
rhs_list = [None] * graph._graph.number_of_etypes()
x = data_dict_to_tuple(graph, x, mfunc.name, lhs_list, rhs_list)
z = op(graph, x)
return {rfunc.out_field : z}
......
...@@ -4682,6 +4682,10 @@ class DGLHeteroGraph(object):
"""Send messages along all the edges of the specified type
and update all the nodes of the corresponding destination type.
For heterogeneous graphs with more than one relation type, send messages
along all the edges, reduce them both type-wise and across different types
at the same time, then update the node features of all the nodes.
Parameters
----------
message_func : dgl.function.BuiltinFunction or callable
...@@ -4743,13 +4747,58 @@ class DGLHeteroGraph(object):
tensor([[0.],
[0.],
[3.]])
**Heterogeneous graph (number of relation types > 1)**
>>> g = dgl.heterograph({
... ('user', 'follows', 'user'): ([0, 1], [1, 1]),
... ('game', 'attracts', 'user'): ([0], [1])
... })
Update all.
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
>>> g.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
>>> g.nodes['user'].data['h']
tensor([[0.],
[4.]])
""" """
etid = self.get_etype_id(etype) # Graph with one relation type
etype = self.canonical_etypes[etid] if self._graph.number_of_etypes() == 1 or etype is not None:
_, dtid = self._graph.metagraph.find_edge(etid) etid = self.get_etype_id(etype)
g = self if etype is None else self[etype] etype = self.canonical_etypes[etid]
ndata = core.message_passing(g, message_func, reduce_func, apply_node_func) _, dtid = self._graph.metagraph.find_edge(etid)
self._set_n_repr(dtid, ALL, ndata) g = self if etype is None else self[etype]
ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
self._set_n_repr(dtid, ALL, ndata)
else: # heterogeneous graph with number of relation types > 1
if not core.is_builtin(message_func) or not core.is_builtin(reduce_func):
raise DGLError("User defined functions are not yet "
"supported in update_all for heterogeneous graphs. "
"Please use multi_update_all instead.")
if reduce_func.name in ['max', 'min']:
raise NotImplementedError("Reduce op \'" + reduce_func.name + "\' is not yet "
"supported in update_all for heterogeneous graphs. "
"Please use multi_update_all instead.")
if reduce_func.name in ['mean']:
raise NotImplementedError("Cannot set both intra-type and inter-type reduce "
"operators as 'mean' using update_all. Please use "
"multi_update_all instead.")
if message_func.name not in ['copy_u', 'copy_e']:
raise NotImplementedError("Op \'" + message_func.name + "\' is not yet supported "
"in update_all for heterogeneous graphs. Please use "
"multi_update_all instead.")
g = self
all_out = core.message_passing(g, message_func, reduce_func, apply_node_func)
key = list(all_out.keys())[0]
out_tensor_tuples = all_out[key]
dst_tensor = {}
for _, _, dsttype in g.canonical_etypes:
dtid = g.get_ntype_id(dsttype)
dst_tensor[key] = out_tensor_tuples[dtid]
self._node_frames[dtid].update(dst_tensor)
#################################################################
# Message passing on heterograph
...@@ -4848,6 +4897,7 @@
if apply_node_func is not None:
self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid])
#################################################################
# Message propagation
#################################################################
...@@ -5645,6 +5695,7 @@
gidx = self._graph.shared_memory(name, self.ntypes, self.etypes, formats)
return DGLHeteroGraph(gidx, self.ntypes, self.etypes)
def long(self):
"""Cast the graph to one with idtype int64
......
...@@ -3,10 +3,35 @@ from itertools import product
import sys
from ..backend import gsddmm as gsddmm_internal
from ..backend import gsddmm_hetero as gsddmm_internal_hetero
from .. import backend as F
__all__ = ['gsddmm', 'copy_u', 'copy_v', 'copy_e']
def reshape_lhs_rhs(lhs_data, rhs_data):
r""" Expand dims so that there will be no broadcasting issues with different
number of dimensions. For example, given two shapes (N, 3, 1), (E, 5, 3, 4)
that are valid broadcastable shapes, change them to (N, 1, 3, 1) and
(E, 5, 3, 4)
Parameters
----------
lhs_data : tensor or None
The left operand, could be None if it's not required by op.
rhs_data : tensor or None
The right operand, could be None if it's not required by op.
"""
lhs_shape = F.shape(lhs_data)
rhs_shape = F.shape(rhs_data)
if len(lhs_shape) != len(rhs_shape):
max_ndims = max(len(lhs_shape), len(rhs_shape))
lhs_pad_ndims = max_ndims - len(lhs_shape)
rhs_pad_ndims = max_ndims - len(rhs_shape)
new_lhs_shape = (lhs_shape[0],) + (1,) * lhs_pad_ndims + lhs_shape[1:]
new_rhs_shape = (rhs_shape[0],) + (1,) * rhs_pad_ndims + rhs_shape[1:]
lhs_data = F.reshape(lhs_data, new_lhs_shape)
rhs_data = F.reshape(rhs_data, new_rhs_shape)
return lhs_data, rhs_data
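A quick sanity check of the padding rule above, assuming the PyTorch backend (where ``F.shape``/``F.reshape`` behave like the tensor's own ``shape``/``reshape``); the shapes below are made up:

```python
import torch

lhs = torch.randn(7, 3, 1)       # e.g. per-node features of shape (N, 3, 1)
rhs = torch.randn(11, 5, 3, 4)   # e.g. per-edge features of shape (E, 5, 3, 4)

pad = rhs.dim() - lhs.dim()      # lhs is short by one broadcastable dimension
lhs = lhs.reshape(lhs.shape[:1] + (1,) * pad + lhs.shape[1:])

print(lhs.shape, rhs.shape)      # torch.Size([7, 1, 3, 1]) torch.Size([11, 5, 3, 4])
# The trailing dims (1, 3, 1) and (5, 3, 4) now broadcast elementwise without ambiguity.
```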
def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
...@@ -43,24 +68,28 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
tensor
The result tensor.
"""
if g._graph.number_of_etypes() == 1:
if op not in ['copy_lhs', 'copy_rhs']:
lhs_data, rhs_data = reshape_lhs_rhs(lhs_data, rhs_data)
return gsddmm_internal(
g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target)
else:
lhs_data_dict = lhs_data
rhs_data_dict = rhs_data
lhs_list = [None] * g._graph.number_of_ntypes()
rhs_list = [None] * g._graph.number_of_ntypes()
for srctype, _, dsttype in g.canonical_etypes:
src_id = g.get_ntype_id(srctype)
dst_id = g.get_ntype_id(dsttype)
lhs_data = lhs_data_dict[srctype]
rhs_data = rhs_data_dict[dsttype]
if op not in ['copy_lhs', 'copy_rhs']:
lhs_data, rhs_data = reshape_lhs_rhs(lhs_data, rhs_data)
lhs_list[src_id] = lhs_data
rhs_list[dst_id] = rhs_data
lhs_and_rhs_tuple = tuple(lhs_list + rhs_list)
# With max and min reducers infinity will be returned for zero degree nodes
return gsddmm_internal_hetero(g, op, lhs_target, rhs_target, *lhs_and_rhs_tuple)
def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
name = "{}_{}_{}".format(lhs_target, binary_op, rhs_target)
......
...@@ -2,10 +2,35 @@
import sys
from ..backend import gspmm as gspmm_internal
from ..backend import gspmm_hetero as gspmm_internal_hetero
from .. import backend as F
__all__ = ['gspmm']
def reshape_lhs_rhs(lhs_data, rhs_data):
r""" Expand dims so that there will be no broadcasting issues with different
number of dimensions. For example, given two shapes (N, 3, 1), (E, 5, 3, 4)
that are valid broadcastable shapes, change them to (N, 1, 3, 1) and
(E, 5, 3, 4)
Parameters
----------
lhs_data : tensor or None
The left operand, could be None if it's not required by op.
rhs_data : tensor or None
The right operand, could be None if it's not required by op.
"""
lhs_shape = F.shape(lhs_data)
rhs_shape = F.shape(rhs_data)
if len(lhs_shape) != len(rhs_shape):
max_ndims = max(len(lhs_shape), len(rhs_shape))
lhs_pad_ndims = max_ndims - len(lhs_shape)
rhs_pad_ndims = max_ndims - len(rhs_shape)
new_lhs_shape = (lhs_shape[0],) + (1,) * lhs_pad_ndims + lhs_shape[1:]
new_rhs_shape = (rhs_shape[0],) + (1,) * rhs_pad_ndims + rhs_shape[1:]
lhs_data = F.reshape(lhs_data, new_lhs_shape)
rhs_data = F.reshape(rhs_data, new_rhs_shape)
return lhs_data, rhs_data
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
r""" Generalized Sparse Matrix Multiplication interface.
...@@ -43,28 +68,24 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
tensor
The result tensor.
"""
if g._graph.number_of_etypes() == 1:
if op not in ['copy_lhs', 'copy_rhs']:
lhs_data, rhs_data = reshape_lhs_rhs(lhs_data, rhs_data)
# With max and min reducers infinity will be returned for zero degree nodes
ret = gspmm_internal(g._graph, op,
'sum' if reduce_op == 'mean' else reduce_op,
lhs_data, rhs_data)
# Replace infinity with zero for isolated nodes when reducer is min/max
if reduce_op in ['min', 'max']:
ret = F.replace_inf_with_zero(ret)
else:
if op in ['copy_lhs', 'copy_rhs']:
lhs_and_rhs_tuple = lhs_data if rhs_data is None else rhs_data
ret = gspmm_internal_hetero(g, op,
'sum' if reduce_op == 'mean' else reduce_op,
*lhs_and_rhs_tuple)
# TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph
# divide in degrees for mean reducer.
if reduce_op == 'mean':
......
...@@ -179,6 +179,85 @@ def _gspmm(gidx, op, reduce_op, u, e):
return v, (arg_u, arg_e)
def _gspmm_hetero(g, op, reduce_op, u_and_e_tuple):
r""" Generalized Sparse Matrix Multiplication interface.
"""
num_ntypes = g._graph.number_of_ntypes()
u_tuple, e_tuple = u_and_e_tuple[:num_ntypes], u_and_e_tuple[num_ntypes:]
gidx = g._graph
use_u = op != 'copy_rhs'
use_e = op != 'copy_lhs'
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
# deal with scalar features.
expand_u, expand_e = False, False
list_u = [None] * gidx.number_of_ntypes()
list_v = [None] * gidx.number_of_ntypes()
list_e = [None] * gidx.number_of_etypes()
for rel in g.canonical_etypes:
srctype, _, dsttype = rel
etid = g.get_etype_id(rel)
src_id = g.get_ntype_id(srctype)
dst_id = g.get_ntype_id(dsttype)
u = u_tuple[src_id] if use_u else None
e = e_tuple[etid] if use_e else None
if use_u:
if u is not None and F.ndim(u) == 1:
u = F.unsqueeze(u, -1)
expand_u = True
list_u[src_id] = u if use_u else None
if use_e:
if e is not None and F.ndim(e) == 1:
e = F.unsqueeze(e, -1)
expand_e = True
list_e[etid] = e if use_e else None
ctx = F.context(u) if use_u else F.context(e) # TODO(Israt): Put outside of loop
dtype = F.dtype(u) if use_u else F.dtype(e) # TODO(Israt): Put outside of loop
u_shp = F.shape(u) if use_u else (0,)
e_shp = F.shape(e) if use_e else (0,)
v_shp = (gidx.number_of_nodes(dst_id), ) +\
infer_broadcast_shape(op, u_shp[1:], e_shp[1:])
list_v[dst_id] = F.zeros(v_shp, dtype, ctx)
use_cmp = reduce_op in ['max', 'min']
arg_u, arg_e = None, None
idtype = getattr(F, gidx.dtype)
if use_cmp:
if use_u:
arg_u = F.zeros(v_shp, idtype, ctx)
if use_e:
arg_e = F.zeros(v_shp, idtype, ctx)
arg_u_nd = to_dgl_nd_for_write(arg_u)
arg_e_nd = to_dgl_nd_for_write(arg_e)
if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSpMMHetero(gidx, op, reduce_op,
[to_dgl_nd(u_i) for u_i in list_u],
[to_dgl_nd(e_i) for e_i in list_e],
[to_dgl_nd_for_write(v_i) for v_i in list_v],
arg_u_nd,
arg_e_nd)
arg_u = None if arg_u is None else F.zerocopy_from_dgl_ndarray(arg_u_nd)
arg_e = None if arg_e is None else F.zerocopy_from_dgl_ndarray(arg_e_nd)
# To deal with scalar node/edge features.
for l in range(gidx.number_of_ntypes()):
# replace None by empty tensor. Forward func doesn't accept None in tuple.
v = list_v[l]
v = F.tensor([]) if v is None else v
if ((expand_u or not use_u) and (expand_e or not use_e)):
v = F.squeeze(v, -1) # To deal with scalar node/edge features.
list_v[l] = v
out = tuple(list_v)
if expand_u and use_cmp:
arg_u = F.squeeze(arg_u, -1)
if expand_e and use_cmp:
arg_e = F.squeeze(arg_e, -1)
return out, (arg_u, arg_e)
def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. It
takes the result of :attr:`op` on source node feature and destination node
...@@ -239,6 +318,7 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
expand_rhs = True
lhs_target = target_mapping[lhs_target]
rhs_target = target_mapping[rhs_target]
ctx = F.context(lhs) if use_lhs else F.context(rhs)
dtype = F.dtype(lhs) if use_lhs else F.dtype(rhs)
lhs_shp = F.shape(lhs) if use_lhs else (0,)
...@@ -257,6 +337,70 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
return out
def _gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
"""
num_ntypes = g._graph.number_of_ntypes()
lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:num_ntypes], lhs_and_rhs_tuple[num_ntypes:]
gidx = g._graph
use_lhs = op != 'copy_rhs'
use_rhs = op != 'copy_lhs'
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
# deal with scalar features.
expand_lhs, expand_rhs = False, False
lhs_target = target_mapping[lhs_target]
rhs_target = target_mapping[rhs_target]
lhs_list = [None] * gidx.number_of_ntypes()
rhs_list = [None] * gidx.number_of_ntypes()
out_list = [None] * gidx.number_of_etypes()
for rel in g.canonical_etypes:
srctype, _, dsttype = rel
etid = g.get_etype_id(rel)
src_id = g.get_ntype_id(srctype)
dst_id = g.get_ntype_id(dsttype)
lhs = lhs_tuple[src_id]
rhs = rhs_tuple[dst_id]
if use_lhs:
if lhs is not None and F.ndim(lhs) == 1:
lhs = F.unsqueeze(lhs, -1)
expand_lhs = True
if use_rhs:
if rhs is not None and F.ndim(rhs) == 1:
rhs = F.unsqueeze(rhs, -1)
expand_rhs = True
ctx = F.context(lhs) if use_lhs else F.context(rhs)
dtype = F.dtype(lhs) if use_lhs else F.dtype(rhs)
lhs_shp = F.shape(lhs) if use_lhs else (0,)
rhs_shp = F.shape(rhs) if use_rhs else (0,)
lhs_list[src_id] = lhs if use_lhs else None
rhs_list[dst_id] = rhs if use_rhs else None
out_shp = (gidx.number_of_edges(etid), ) +\
infer_broadcast_shape(op, lhs_shp[1:], rhs_shp[1:])
out_list[etid] = F.zeros(out_shp, dtype, ctx)
if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSDDMMHetero(gidx, op,
[to_dgl_nd(lhs) for lhs in lhs_list],
[to_dgl_nd(rhs) for rhs in rhs_list],
[to_dgl_nd_for_write(out) for out in out_list],
lhs_target, rhs_target)
for l in range(gidx.number_of_etypes()):
# Replace None by empty tensor. Forward func doesn't accept None in tuple.
e = out_list[l]
e = F.tensor([]) if e is None else e
if (expand_lhs or not use_lhs) and (expand_rhs or not use_rhs):
e = F.squeeze(e, -1)
out_list[l] = e
out = tuple(out_list)
return out
def _segment_reduce(op, feat, offsets):
r"""Segment reduction operator.
......
...@@ -362,9 +362,6 @@ void CusparseCsrmm2Hetero(
CUSPARSE_CALL(cusparseDestroyDnMat(matB));
CUSPARSE_CALL(cusparseDestroyDnMat(matC));
#else
cusparseMatDescr_t descr;
CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
...@@ -378,11 +375,8 @@ void CusparseCsrmm2Hetero(
descr, (valptr)? valptr : A_data,
static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, C_data, m));
CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
#endif
if (valptr)
device->FreeWorkspace(ctx, valptr);
...@@ -521,54 +515,85 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
const std::vector<dgl_type_t>& out_ntids) {  // output node type id
bool is_scalar_efeat = vec_efeat.size() != 0;
bool use_efeat = op != "copy_lhs";
// TODO(Israt): Resolve PR-https://github.com/dmlc/dgl/issues/2995 and use multistream
auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
SWITCH_BITS(bits, DType, {
std::vector<DType*> trans_out(vec_out.size(), NULL);
bool use_legacy_cusparsemm =
(CUDART_VERSION < 11000) &&
((op == "copy_lhs" && cusparse_available<bits, IdType>()) ||
(op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>()));
#if CUDART_VERSION < 11000
// Create temporary output buffer to store non-transposed output
if (use_legacy_cusparsemm) {
for (dgl_type_t ntype = 0; ntype < vec_out.size(); ++ntype) {
const int m = vec_out[ntype]->shape[0];
const int n = vec_out[ntype]->shape[1];
if (m == 0) continue;
DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx,
m * n * sizeof(DType)));
CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
trans_out[ntype] = out;
}
}
#endif
// Check shape of ufeat for all relation types and compute feature size
int64_t x_length = 1;
for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes";
for (int i = 1; i < ufeat->ndim; ++i) {
if (ufeat->shape[i] != next_ufeat->shape[i]) {
if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
LOG(FATAL) <<
"Homogenized message passing on heterogeneous graphs does not support " <<
"automatic broadcasting. Please manually broadcast it before calling " <<
"message passing functions.";
else
LOG(FATAL) << "Input features have different shapes.";
return;
}
if (etype == 0)
x_length *= ufeat->shape[i];
}
}
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
const dgl_type_t src_id = ufeat_ntids[etype];
const dgl_type_t dst_id = out_ntids[etype];
CSRMatrix csr = vec_csr[etype];
if (reduce == "sum") {
/* Call SpMM for each relation type */
if (op == "copy_lhs" && cusparse_available<bits, IdType>()) {  // cusparse
/* If CUDA is less than 11.0, put the output in trans_out for later transposition */
DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
static_cast<DType*>(vec_out[dst_id]->data);
cusparse::CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
nullptr,
out,
x_length, thr_entry->stream);
} else if (op == "mul" && is_scalar_efeat &&
cusparse_available<bits, IdType>()) {  // cusparse
NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(vec_efeat[etype], csr.data);
cusparse::CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data),
// TODO(Israt): Change vec_out to trans_out to support CUDA version < 11
static_cast<DType*>(vec_out[dst_id]->data),
x_length, thr_entry->stream);
} else {  // general kernel
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
...@@ -580,35 +605,49 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
NullArray(), NullArray(), thr_entry->stream);
});
}
} else if (reduce == "max") {
// SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id],
out_aux[0], out_aux[1], thr_entry->stream);
});
// });
} else if (reduce == "min") {
// SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id],
out_aux[0], out_aux[1], thr_entry->stream);
// });
});
} else {
LOG(FATAL) << "Not implemented";
}
}
#if CUDART_VERSION < 11000
if (use_legacy_cusparsemm) {
// transpose output
for (dgl_type_t ntype = 0; ntype < vec_out.size(); ++ntype) {
const int m = vec_out[ntype]->shape[0];
const int n = vec_out[ntype]->shape[1];
if (m == 0) continue;
DType *C_data = static_cast<DType*>(vec_out[ntype]->data);
_Transpose(trans_out[ntype], C_data, n, m);
device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
}
}
#endif
});
}
/*!
......
...@@ -160,6 +160,17 @@ __global__ void SpMMCsrKernel(
DType out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
ReduceOp::Call(&local_accum, &local_argu, &local_arge, out, cid, eid);
}
// TODO(isratnisa, BarclayII)
// The use of += is a quick hack to compute for cross-type reducing
// C = SpMM(SpA, B) + C
// To make it work on max-reducer and min-reducer, i.e.
// C = Max(SpMM<BinaryOp, Max>(SpA, B), C)
// it requires at least the following:
// 1. Initialize the output buffer with ReducerOp::zero.
// 2. Record also which edge type has the maximum/minimum in argmax/argmin.
// This requires non-trivial changes in SpMMCsrKernel itself or writing a new kernel.
// So we leave it to future PRs.
out[ty * out_len + tx] += local_accum;
if (ReduceOp::require_arg && BinaryOp::use_lhs)
arg_u[ty * out_len + tx] = local_argu;
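The ``+=`` makes each relation type's SpMM result accumulate into the shared per-destination-type buffer, which is exactly a cross-type sum-reduction. A toy NumPy illustration (not the CUDA kernel), reusing the graph from the ``update_all`` docstring example above:

```python
import numpy as np

X_user = np.array([[1.0], [2.0]])   # features of the 2 'user' nodes
X_game = np.array([[1.0]])          # feature of the 1 'game' node

A_follows = np.array([[0.0, 0.0],   # user->user edges 0->1 and 1->1, rows = dst users
                      [1.0, 1.0]])
A_attracts = np.array([[0.0],       # game->user edge 0->1
                       [1.0]])

out_user = np.zeros((2, 1))         # output buffer of the destination type 'user'
out_user += A_follows @ X_user      # relation 1: SpMM, then accumulate
out_user += A_attracts @ X_game     # relation 2: accumulates into the same buffer
print(out_user)                     # [[0.], [4.]] - matches the docstring example
```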
......
...@@ -45,7 +45,7 @@ void SpMM(const std::string& op, const std::string& reduce,
op, reduce, bcast, graph->GetCOOMatrix(0),
ufeat, efeat, out, out_aux);
} else {
LOG(FATAL) << "SpMM only supports CSC and COO formats";
}
});
});
...@@ -76,8 +76,7 @@ void SpMMHetero(const std::string& op, const std::string& reduce,
NDArray ufeat = (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[ufeat_eid[0]];
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out[out_eid[0]]->dtype, bits, "Feature data", {
if (format == SparseFormat::kCSC) {
...@@ -85,14 +84,10 @@ void SpMMHetero(const std::string& op, const std::string& reduce,
op, reduce, bcast, vec_graph,
ufeat_vec, efeat_vec, out, out_aux,
ufeat_eid, out_eid);
} else {
// TODO(Israt): Add support for COO format
LOG(FATAL) << "SpMM only supports CSC format for graphs with number "
<< "of relation types > 1";
}
});
});
...@@ -124,7 +119,7 @@ void SDDMM(const std::string& op,
op, bcast, graph->GetCOOMatrix(0),
lhs, rhs, out, lhs_target, rhs_target);
} else {
LOG(FATAL) << "SDDMM only supports CSR and COO formats";
}
});
});
...@@ -154,8 +149,7 @@ void SDDMMHetero(const std::string& op,
}
const auto &bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out[rhs_eid[0]]->dtype, bits, "Feature data", {
if (format == SparseFormat::kCSR) {
...@@ -163,13 +157,10 @@ void SDDMMHetero(const std::string& op,
op, bcast, vec_csr,
lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid);
} else {
// TODO(Israt): Add support for COO format
LOG(FATAL) << "SDDMM only supports CSR format for graphs with number "
<< "of relation types > 1";
}
});
});
......
import dgl
import dgl.function as fn
from collections import Counter
import numpy as np
import scipy.sparse as ssp
import itertools
import backend as F
import networkx as nx
import unittest, pytest
from dgl import DGLError
import test_utils
from test_utils import parametrize_dtype, get_cases
from scipy.sparse import rand
rfuncs = {'sum': fn.sum, 'max': fn.max, 'min': fn.min, 'mean': fn.mean}
fill_value = {'sum': 0, 'max': float("-inf")}
feat_size = 2
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
def create_test_heterograph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
# metagraph:
# ('user', 'follows', 'user'),
# ('user', 'plays', 'game'),
# ('user', 'wishes', 'game'),
# ('developer', 'develops', 'game')])
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1, 2, 1], [0, 0, 1, 1]),
('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
('user', 'wishes', 'game'): ([0, 1, 1], [0, 0, 1]),
('developer', 'develops', 'game'): ([0, 1, 0], [0, 1, 1]),
}, idtype=idtype, device=F.ctx())
assert g.idtype == idtype
assert g.device == F.ctx()
return g
# def init_features(idtype):
@parametrize_dtype
def test_unary_copy_u(idtype):
def _test(mfunc, rfunc):
g = create_test_heterograph(idtype)
x1 = F.randn((g.num_nodes('user'), feat_size))
x2 = F.randn((g.num_nodes('developer'), feat_size))
F.attach_grad(x1)
F.attach_grad(x2)
g.nodes['user'].data['h'] = x1
g.nodes['developer'].data['h'] = x2
#################################################################
# multi_update_all(): call msg_passing separately for each etype
#################################################################
with F.record_grad():
g.multi_update_all(
{'plays' : (mfunc('h', 'm'), rfunc('m', 'y')),
'follows': (mfunc('h', 'm'), rfunc('m', 'y')),
'develops': (mfunc('h', 'm'), rfunc('m', 'y')),
'wishes': (mfunc('h', 'm'), rfunc('m', 'y'))},
'sum')
r1 = g.nodes['game'].data['y']
F.backward(r1, F.randn(r1.shape))
n_grad1 = F.grad(g.nodes['user'].data['h'])
g.nodes['game'].data.clear()
#################################################################
# update_all(): call msg_passing for all etypes
#################################################################
g.update_all(mfunc('h', 'm'), rfunc('m', 'y'))
r2 = g.nodes['game'].data['y']
F.backward(r2, F.randn(r2.shape))
n_grad2 = F.grad(g.nodes['user'].data['h'])
# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y):
print('@{} {} v.s. {}'.format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if not F.allclose(n_grad1, n_grad2):
print('node grad')
_print_error(n_grad1, n_grad2)
assert(F.allclose(n_grad1, n_grad2))
_test(fn.copy_u, fn.sum)
# TODO(Israt): Add reduce func to support the following reduce ops
# _test('copy_u', 'max')
# _test('copy_u', 'min')
# _test('copy_u', 'mean')
@parametrize_dtype
def test_unary_copy_e(idtype):
def _test(mfunc, rfunc):
g = create_test_heterograph(idtype)
feat_size = 2
x1 = F.randn((4,feat_size))
x2 = F.randn((4,feat_size))
x3 = F.randn((3,feat_size))
x4 = F.randn((3,feat_size))
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g['plays'].edata['eid'] = x1
g['follows'].edata['eid'] = x2
g['develops'].edata['eid'] = x3
g['wishes'].edata['eid'] = x4
#################################################################
# multi_update_all(): call msg_passing separately for each etype
#################################################################
with F.record_grad():
g.multi_update_all(
{'plays' : (mfunc('eid', 'm'), rfunc('m', 'y')),
'follows': (mfunc('eid', 'm'), rfunc('m', 'y')),
'develops': (mfunc('eid', 'm'), rfunc('m', 'y')),
'wishes': (mfunc('eid', 'm'), rfunc('m', 'y'))},
'sum')
r1 = g.nodes['game'].data['y']
F.backward(r1, F.randn(r1.shape))
e_grad1 = F.grad(g['develops'].edata['eid'])
#################################################################
# update_all(): call msg_passing for all etypes
#################################################################
# TODO(Israt): output type can be None in multi_update and empty
# tensor in new_update_all
g.update_all(mfunc('eid', 'm'), rfunc('m', 'y'))
r2 = g.nodes['game'].data['y']
F.backward(r2, F.randn(r2.shape))
e_grad2 = F.grad(g['develops'].edata['eid'])
# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y):
print('@{} {} v.s. {}'.format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if not F.allclose(e_grad1, e_grad2):
print('edge grad')
_print_error(e_grad1, e_grad2)
assert(F.allclose(e_grad1, e_grad2))
_test(fn.copy_e, fn.sum)
# TODO(Israt): Add reduce func to support the following reduce ops
# _test('copy_e', 'max')
# _test('copy_e', 'min')
# _test('copy_e', 'mean')
if __name__ == '__main__':
test_unary_copy_u()
test_unary_copy_e()