Unverified commit 298e4fa6, authored by Israt Nisa, committed by GitHub

[Feature] Support builtin binary message functions for heterogeneous graphs (#3273)



* Added binary builtinMsgFunc forward() for heterograph

* Added backward for u_op_v

* Supports all binary builtin forward

* Supports binary message funcs with reduce func sum

* lint check

* removed import torch from unittest

* enabled GPU test

* lint check

* Fixed docstrings

* rename func get_hs_id

* edited comment
Co-authored-by: default avatarIsrat Nisa <nisisrat@amazon.com>
parent c81efdf2
@@ -1467,7 +1467,7 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
     """
     pass

-def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
+def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
     r""" Generalized Sparse Matrix Multiplication interface on heterogeneous graph.
     All the relation types of the heterogeneous graph will be processed together.
     It fuses two steps into one kernel.
@@ -1493,6 +1493,8 @@ def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
         ``copy_lhs``, ``copy_rhs``.
     reduce_op : str
         Reduce operator, could be ``sum``, ``max``, ``min``.
+    lhs_len : int
+        Length of the lhs data
     lhs_and_rhs_tuple : tuple of tensors
         lhs_data and rhs_data are concatenated to one tuple. lhs_data is
         also a tuple of tensors of size number of ntypes. Same is true for
@@ -1541,7 +1543,7 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
     """
     pass

-def gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
+def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
     r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface on
     heterogeneous graph. All the relation types of the heterogeneous graph
     will be processed together.
@@ -1562,6 +1564,8 @@ def gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
     op : str
         Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
         ``copy_lhs``, ``copy_rhs``.
+    lhs_len : int
+        Length of the lhs data
     lhs_target: str
         Choice of `u`(source), `e`(edge) or `v`(destination) for left operand.
     rhs_target: str
...
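Why the new `lhs_len` argument: previously the lhs slice of the packed tuple was always one slot per node type, so the split point could be recomputed inside the kernel. With binary message functions the rhs operand may instead be per edge type, so the caller must say where lhs ends. A minimal sketch of the packing convention (tensor names are illustrative, not from the diff):

    # One slot per source ntype on the lhs, one slot per etype on the rhs;
    # slots stay positional so missing types show up as explicit Nones.
    lhs_data = (h_user, h_game)                  # len == number of ntypes
    rhs_data = (e_plays, e_follows, e_develops)  # len == number of etypes
    packed = tuple(list(lhs_data) + list(rhs_data))
    lhs_len = len(lhs_data)
    # The callee recovers the two operands from the flat tuple:
    lhs_tuple, rhs_tuple = packed[:lhs_len], packed[lhs_len:]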
...@@ -2,7 +2,7 @@ import torch as th ...@@ -2,7 +2,7 @@ import torch as th
from distutils.version import LooseVersion from distutils.version import LooseVersion
from ...base import is_all, ALL from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _scatter_add from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask from ...sparse import _csrmm, _csrsum, _csrmask, get_typeid_by_target
from ...heterograph_index import create_unitgraph_from_csr from ...heterograph_index import create_unitgraph_from_csr
if LooseVersion(th.__version__) >= LooseVersion("1.6.0"): if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
...@@ -192,45 +192,75 @@ class GSpMM(th.autograd.Function): ...@@ -192,45 +192,75 @@ class GSpMM(th.autograd.Function):
class GSpMM_hetero(th.autograd.Function): class GSpMM_hetero(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16) @custom_fwd(cast_inputs=th.float16)
def forward(ctx, g, op, reduce_op, *feats): # feats = lhs_data + rhs_data def forward(ctx, g, op, reduce_op, X_len, *feats): # feats = lhs_data + rhs_data
out, (argX, argY) = _gspmm_hetero(g, op, reduce_op, feats) out, (argX, argY) = _gspmm_hetero(g, op, reduce_op, X_len, feats)
ctx.backward_cache = g, op, reduce_op X, Y = feats[:X_len], feats[X_len:]
# TODO (Israt): check target to decide src_id/dst_id?
# checking the first relation to decide for all the relations
src_id, dst_id = g._graph.metagraph.find_edge(0)
reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id])
X_shape = tuple([X[i].shape if X[i] is not None else None
for i in range(X_len)])
Y_shape = tuple([Y[i].shape if Y[i] is not None else None
for i in range(len(Y))])
dtype = X[src_id].dtype if X[src_id] is not None else Y[dst_id].dtype
device = X[src_id].device if X[src_id] is not None else Y[dst_id].device
ctx.backward_cache = g, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len
req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False
for i in range(X_len)])
req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False
for i in range(len(Y))])
# checking the first relation to decide for all the relations
if not spmm_cache_argX(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
argX = None
if not spmm_cache_argY(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
argY = None
ctx.save_for_backward(*feats, argX, argY) ctx.save_for_backward(*feats, argX, argY)
return out return out
@staticmethod @staticmethod
@custom_bwd @custom_bwd
def backward(ctx, *dZ): def backward(ctx, *dZ):
g, op, reduce_op = ctx.backward_cache g, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
feats = ctx.saved_tensors[:-2] feats = ctx.saved_tensors[:-2]
argX = ctx.saved_tensors[-2] argX = ctx.saved_tensors[-2]
argY = ctx.saved_tensors[-1] argY = ctx.saved_tensors[-1]
num_ntypes = g._graph.number_of_ntypes() X, Y = feats[:X_len], feats[X_len:]
X, Y = feats[:num_ntypes], feats[num_ntypes:]
if op != 'copy_rhs' and any([x is not None for x in X]): if op != 'copy_rhs' and any([x is not None for x in X]):
g_rev = g.reverse() g_rev = g.reverse()
# TODO(Israt): implement other combinations of message and reduce functions # TODO(Israt): implement other combinations of message and reduce functions
if reduce_op == 'sum': if reduce_op == 'sum':
if op == 'copy_lhs': if op == 'mul':
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', *dZ) dX = gspmm_hetero(g_rev, 'mul', 'sum', len(X), *tuple(dZ + Y))
dX = tuple([_reduce_grad(dX[i], X[i].shape) if X[i] is not None else None elif op == 'add':
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', len(X), *tuple(dZ + Y))
elif op == 'copy_lhs':
tpl_None = tuple([None] * len(Y))
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', len(X), *tuple(dZ + tpl_None))
dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None
for i in range(len(X))]) for i in range(len(X))])
else: # X has not gradient else: # X has not gradient
dX = tuple([None] * len(X)) dX = tuple([None] * len(X))
if op != 'copy_lhs' and any([y is not None for y in Y]): if op != 'copy_lhs' and any([y is not None for y in Y]):
# TODO(Israt): implement other combinations of message and reduce functions # TODO(Israt): implement other combinations of reduce functions
if reduce_op == 'sum': if reduce_op == 'sum':
if op in ['copy_rhs']: tpl_dZ = tuple([dZ[i] if dZ[i] is not None else None
tmp_Z = tuple([dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))]) for i in range(len(dZ))])
tmp = tuple(X + tmp_Z) tpl_X_dZ = tuple(X + tpl_dZ)
dY = gsddmm_hetero(g, 'copy_rhs', 'u', 'v', *tmp) if op == 'mul' and reduce_last:
dY = tuple([_reduce_grad(dY[i], Y[i].shape) if Y[i] is not None else None dY = gsddmm_hetero(g, 'dot', X_len, 'u', 'v', *tpl_X_dZ)
elif op == 'mul':
dY = gsddmm_hetero(g, 'mul', X_len, 'u', 'v', *tpl_X_dZ)
elif op in ['add', 'copy_rhs']:
dY = gsddmm_hetero(g, 'copy_rhs', X_len, 'u', 'v', *tpl_X_dZ)
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
for i in range(len(Y))]) for i in range(len(Y))])
else: # Y has no gradient else: # Y has no gradient
dY = tuple([None] * len(Y)) dY = tuple([None] * len(Y))
return (None, None, None) + dX + dY return (None, None, None, None) + dX + dY
def sddmm_cache_X(op, req_grad_X, req_grad_Y): def sddmm_cache_X(op, req_grad_X, req_grad_Y):
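For the `sum` reducer the new branches follow the usual SpMM calculus: with Z = A @ X, the gradient w.r.t. X is an SpMM of dZ on the reversed graph, with an extra `mul` by Y when the forward op multiplied by an edge operand. A dense, homogeneous sanity check of that rule (plain PyTorch, illustrative only, not DGL API):

    import torch

    A = torch.tensor([[1., 0.], [1., 1.]])     # dst x src adjacency
    X = torch.randn(2, 3, requires_grad=True)  # source features
    Z = A @ X                                  # op='copy_lhs', reduce='sum'
    Z.sum().backward()                         # dZ is all ones
    # backward is a copy_lhs-SpMM on the reversed graph: dX = A^T @ dZ
    assert torch.allclose(X.grad, A.t() @ torch.ones(2, 3))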
@@ -315,18 +345,82 @@ class GSDDMM(th.autograd.Function):
 class GSDDMM_hetero(th.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=th.float16)
-    def forward(ctx, g, op, lhs_target, rhs_target, *feats):  # feats = X+Y
-        out = _gsddmm_hetero(g, op, lhs_target, rhs_target, feats)
-        ctx.backward_cache = g, op, lhs_target, rhs_target
+    def forward(ctx, g, op, X_len, lhs_target, rhs_target, *feats):  # feats = X+Y
+        out = _gsddmm_hetero(g, op, X_len, lhs_target, rhs_target, feats)
+        X, Y = feats[:X_len], feats[X_len:]
+        X_shape = tuple([X[i].shape if X[i] is not None else None
+                         for i in range(len(X))])
+        Y_shape = tuple([Y[i].shape if Y[i] is not None else None
+                         for i in range(len(Y))])
+        ctx.backward_cache = g, op, lhs_target, rhs_target, X_shape, Y_shape, X_len
+        req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False
+                            for i in range(len(X))])
+        req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False
+                            for i in range(len(Y))])
+        lhs_id = get_typeid_by_target(g, g.canonical_etypes[0], lhs_target)
+        rhs_id = get_typeid_by_target(g, g.canonical_etypes[0], rhs_target)
         ctx.save_for_backward(*feats)
         return out

     @staticmethod
     @custom_bwd
-    # TODO(Israt): Implement the backward operator
+    # TODO(Israt): Implement the complete backward operator
     def backward(ctx, *dZ):
-        raise NotImplementedError('Homogenized GSDDMM backward operation is not implemented.')
+        g, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache
+        feats = ctx.saved_tensors
+        X, Y = feats[:X_len], feats[X_len:]
+        if op != 'copy_rhs' and any([x is not None for x in X]):
+            if lhs_target in ['u', 'v']:
+                _g = g if lhs_target == 'v' else g.reverse()
+                tpl_of_None = tuple([None] * len(X))
+                if op in ['add', 'copy_lhs']:
+                    dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
+                else:  # mul, dot
+                    if rhs_target == lhs_target:
+                        dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * Y
+                    elif rhs_target == 'e':
+                        dZ_mul_Y = tuple([dZ[i] * Y[i] if dZ[i] is not None else None
+                                          for i in range(len(Y))])
+                        dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_Y)))
+                    else:  # rhs_target = !lhs_target
+                        dX = gspmm_hetero(_g, 'mul', 'sum', len(X), *tuple(Y + dZ))
+            else:  # lhs_target == 'e'
+                if op in ['add', 'copy_lhs']:
+                    dX = dZ
+                else:  # mul, dot
+                    num_etype = g._graph.number_of_etypes()
+                    dX = gsddmm_hetero(g, 'mul', num_etype, 'e', rhs_target, *tuple(dZ + Y))
+            dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None
+                        for i in range(len(X))])
+        else:
+            dX = tuple([None] * len(X))
+        if op != 'copy_lhs' and any([y is not None for y in Y]):
+            if rhs_target in ['u', 'v']:
+                _g = g if rhs_target == 'v' else g.reverse()
+                tpl_of_None = tuple([None] * len(X))
+                if op in ['add', 'copy_rhs']:
+                    dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
+                else:  # mul, dot
+                    if lhs_target == rhs_target:
+                        dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * X
+                    elif lhs_target == 'e':
+                        dZ_mul_X = tuple([dZ[i] * X[i] if dZ[i] is not None else None
+                                          for i in range(len(X))])
+                        dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_X)))
+                    else:  # rhs_target = !lhs_target
+                        dY = gspmm_hetero(_g, 'mul', 'sum', len(X), *tuple(X + dZ))
+            else:
+                if op in ['add', 'copy_rhs']:
+                    dY = tuple([dZ[i] if dZ[i] is not None else None
+                                for i in range(len(dZ))])
+                else:  # mul, dot
+                    num_etype = g._graph.number_of_etypes()
+                    dY = gsddmm_hetero(g, 'mul', num_etype, 'e', lhs_target, *tuple(dZ + X))
+            dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
+                        for i in range(len(Y))])
+        else:
+            dY = tuple([None] * len(Y))
+        return (None, None, None, None, None) + dX + dY

 class EdgeSoftmax(th.autograd.Function):
     @staticmethod
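The filled-in backward mirrors the homogeneous GSDDMM rules: for a per-edge product z_e = x_src(e) * y_dst(e), the gradient w.r.t. x scatters dZ * y back to source nodes, which is exactly an SpMM with op='mul', reduce='sum' on the suitably oriented graph. A dense sanity check (plain PyTorch, illustrative):

    import torch

    src = torch.tensor([0, 1]); dst = torch.tensor([1, 0])
    x = torch.randn(2, 3, requires_grad=True)
    y = torch.randn(2, 3, requires_grad=True)
    z = x[src] * y[dst]     # per-edge u_mul_v, op='mul'
    z.sum().backward()      # dZ is all ones
    # each source row accumulates y from its incident edges
    assert torch.allclose(x.grad, y.detach()[dst])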
@@ -502,11 +596,33 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
         rhs_data = 1. / rhs_data
     return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)

-def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
-    return GSpMM_hetero.apply(g, op, reduce_op, *lhs_and_rhs_tuple)
-
-def gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
-    return GSDDMM_hetero.apply(g, op, lhs_target, rhs_target, *lhs_and_rhs_tuple)
+def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
+    lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:lhs_len], lhs_and_rhs_tuple[lhs_len:]
+    if op == 'sub':
+        op = 'add'
+        rhs_tuple = tuple([-rhs_tuple[i] if rhs_tuple[i] is not None else None
+                           for i in range(len(rhs_tuple))])
+    if op == 'div':
+        op = 'mul'
+        rhs_tuple = tuple([(1. / rhs_tuple[i]) if rhs_tuple[i] is not None else None
+                           for i in range(len(rhs_tuple))])
+    if op in ['add', 'mul']:
+        lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
+    return GSpMM_hetero.apply(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple)
+
+def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
+    lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:lhs_len], lhs_and_rhs_tuple[lhs_len:]
+    if op == 'sub':
+        op = 'add'
+        rhs_tuple = tuple([-rhs_tuple[i] if rhs_tuple[i] is not None else None
+                           for i in range(len(rhs_tuple))])
+    if op == 'div':
+        op = 'mul'
+        rhs_tuple = tuple([(1. / rhs_tuple[i]) if rhs_tuple[i] is not None else None
+                           for i in range(len(rhs_tuple))])
+    if op in ['add', 'mul']:
+        lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
+    return GSDDMM_hetero.apply(g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple)

 def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
     return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
...
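The wrapper-level rewriting above rests on two elementwise identities, which is why only `add` and `mul` (plus the `copy_*` and `dot` ops) need kernel support. A quick check of the identities themselves (plain PyTorch, illustrative):

    import torch

    u, v = torch.randn(4, 8), torch.randn(4, 8)
    assert torch.allclose(u - v, u + (-v))      # 'sub' -> 'add' with negated rhs
    assert torch.allclose(u / v, u * (1. / v))  # 'div' -> 'mul' with reciprocal rhs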
@@ -184,7 +184,7 @@ def _bucketing(val):
         return bkts
     return unique_val, bucketor

-def data_dict_to_tuple(graph, data_dict, op, lhs_list=None, rhs_list=None):
+def data_dict_to_list(graph, data_dict, func, target):
     """Get node or edge feature data of the given name for all the types.

     Parameters
@@ -194,29 +194,46 @@ def data_dict_to_tuple(graph, data_dict, op, lhs_list=None, rhs_list=None):
     data_dict : dict[str, Tensor] or dict[(str, str, str), Tensor]]
         Node or edge data stored in DGLGraph. The key of the dictionary
         is the node type name or edge type name.
-    op : str
-        The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
-        ``copy_lhs``, ``copy_rhs``.
-    lhs_list : list[tensor] or list[None]
-        The feature on source nodes, could be list of None if op is ``copy_rhs``.
-    rhs_list : list[tensor] or list[None]
-        The feature on edges, could be list of None if op is ``copy_lhs``.
+    func : dgl.function.BaseMessageFunction
+        Built-in message function.
+    target : 'u', 'v' or 'e'
+        The target of the lhs or rhs data

     Returns
     --------
-    data_tuple : tuple(Tensor)
-        Feature data stored in tuple of tensors. The i^th tensor stores the feature
+    data_list : list(Tensor)
+        Feature data stored in a list of tensors. The i^th tensor stores the feature
         data of type ``types[i]``.
     """
-    if op == "copy_u":
+    if isinstance(func, fn.BinaryMessageFunction):
+        if target in ['u', 'v']:
+            output_list = [None] * graph._graph.number_of_ntypes()
+            for srctype, _, dsttype in graph.canonical_etypes:
+                if target == 'u':
+                    src_id = graph.get_ntype_id(srctype)
+                    output_list[src_id] = data_dict[srctype]
+                else:
+                    dst_id = graph.get_ntype_id(dsttype)
+                    output_list[dst_id] = data_dict[dsttype]
+        else:  # target == 'e'
+            output_list = [None] * graph._graph.number_of_etypes()
+            for rel in graph.canonical_etypes:
+                etid = graph.get_etype_id(rel)
+                output_list[etid] = data_dict[rel]
+        return output_list
+    else:
+        if target == 'u':
+            lhs_list = [None] * graph._graph.number_of_ntypes()
             for srctype, _, _ in graph.canonical_etypes:
                 src_id = graph.get_ntype_id(srctype)
                 lhs_list[src_id] = data_dict[srctype]
-    elif op == "copy_e":
+            return lhs_list
+        else:  # target == 'e'
+            rhs_list = [None] * graph._graph.number_of_etypes()
             for rel in graph.canonical_etypes:
                 etid = graph.get_etype_id(rel)
                 rhs_list[etid] = data_dict[rel]
-    return tuple(lhs_list + rhs_list)
+            return rhs_list

 def invoke_gsddmm(graph, func):
     """Invoke g-SDDMM computation on the graph.
@@ -238,10 +255,20 @@ def invoke_gsddmm(graph, func):
         x = alldata[func.lhs][func.lhs_field]
         y = alldata[func.rhs][func.rhs_field]
         op = getattr(ops, func.name)
+        if graph._graph.number_of_etypes() > 1:
+            lhs_target, _, rhs_target = func.name.split("_", 2)
+            x = data_dict_to_list(graph, x, func, lhs_target)
+            y = data_dict_to_list(graph, y, func, rhs_target)
         z = op(graph, x, y)
     else:
         x = alldata[func.target][func.in_field]
         op = getattr(ops, func.name)
+        if graph._graph.number_of_etypes() > 1:
+            # Convert to list as dict is unordered.
+            if func.name == "copy_u":
+                x = data_dict_to_list(graph, x, func, 'u')
+            else:  # "copy_e"
+                x = data_dict_to_list(graph, x, func, 'e')
         z = op(graph, x)
     return {func.out_field : z}
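The `func.name.split("_", 2)` call above recovers the operand targets directly from the builtin's name, since every binary message function is named `<lhs>_<op>_<rhs>`:

    lhs_target, binary_op, rhs_target = 'u_mul_e'.split('_', 2)
    assert (lhs_target, binary_op, rhs_target) == ('u', 'mul', 'e')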
@@ -285,15 +312,19 @@ def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None)
         x = alldata[mfunc.lhs][mfunc.lhs_field]
         y = alldata[mfunc.rhs][mfunc.rhs_field]
         op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
+        if graph._graph.number_of_etypes() > 1:
+            lhs_target, _, rhs_target = mfunc.name.split("_", 2)
+            x = data_dict_to_list(graph, x, mfunc, lhs_target)
+            y = data_dict_to_list(graph, y, mfunc, rhs_target)
         z = op(graph, x, y)
     else:
         x = alldata[mfunc.target][mfunc.in_field]
         op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
-        if graph._graph.number_of_etypes() > 1:
-            # Convert to list as dict is unordered.
-            lhs_list = [None] * graph._graph.number_of_ntypes()
-            rhs_list = [None] * graph._graph.number_of_etypes()
-            x = data_dict_to_tuple(graph, x, mfunc.name, lhs_list, rhs_list)
+        if graph._graph.number_of_etypes() > 1 and not isinstance(x, tuple):
+            if mfunc.name == "copy_u":
+                x = data_dict_to_list(graph, x, mfunc, 'u')
+            else:  # "copy_e"
+                x = data_dict_to_list(graph, x, mfunc, 'e')
         z = op(graph, x)
     return {rfunc.out_field : z}
...
@@ -4861,10 +4861,6 @@ class DGLHeteroGraph(object):
             raise NotImplementedError("Cannot set both intra-type and inter-type reduce "
                                       "operators as 'mean' using update_all. Please use "
                                       "multi_update_all instead.")
-        if message_func.name not in ['copy_u', 'copy_e']:
-            raise NotImplementedError("Op \'" + message_func.name + "\' is not yet supported"
-                                      "in update_all for heterogeneous graphs. Please use"
-                                      "multi_update_all instead.")
         g = self
         all_out = core.message_passing(g, message_func, reduce_func, apply_node_func)
         key = list(all_out.keys())[0]
...
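With the guard removed, `update_all` accepts binary builtins on heterographs directly, mirroring the new test added below. A minimal usage sketch (toy graph; assumes a DGL build that includes this change):

    import dgl
    import dgl.function as fn
    import torch as th

    g = dgl.heterograph({
        ('user', 'follows', 'user'): (th.tensor([0, 1]), th.tensor([1, 2])),
        ('user', 'plays', 'game'): (th.tensor([0, 2]), th.tensor([0, 1]))})
    g.nodes['user'].data['h'] = th.randn(3, 4)
    g.nodes['game'].data['h'] = th.randn(2, 4)
    g.edges['follows'].data['h'] = th.randn(2, 4)
    g.edges['plays'].data['h'] = th.randn(2, 4)
    # binary message function + sum reducer across all etypes at once
    g.update_all(fn.u_mul_e('h', 'h', 'm'), fn.sum('m', 'out'))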
@@ -74,22 +74,13 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
         return gsddmm_internal(
             g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target)
     else:
-        lhs_data_dict = lhs_data
-        rhs_data_dict = rhs_data
-        lhs_list = [None] * g._graph.number_of_ntypes()
-        rhs_list = [None] * g._graph.number_of_ntypes()
-        for srctype, _, dsttype in g.canonical_etypes:
-            src_id = g.get_ntype_id(srctype)
-            dst_id = g.get_ntype_id(dsttype)
-            lhs_data = lhs_data_dict[srctype]
-            rhs_data = rhs_data_dict[dsttype]
-            if op not in ['copy_lhs', 'copy_rhs']:
-                lhs_data, rhs_data = reshape_lhs_rhs(lhs_data, rhs_data)
-            lhs_list[src_id] = lhs_data
-            rhs_list[dst_id] = rhs_data
-        lhs_and_rhs_tuple = tuple(lhs_list + rhs_list)
-        # With max and min reducers infinity will be returned for zero degree nodes
-        return gsddmm_internal_hetero(g, op, lhs_target, rhs_target, *lhs_and_rhs_tuple)
+        # TODO (Israt): Call reshape_lhs_rhs() on lhs and rhs data to match their dimension
+        # and avoid broadcasting issue. Handle the case where different nodes have
+        # different dimensions, and different etypes may need different broadcasting
+        # dims for the same node.
+        lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))
+        return gsddmm_internal_hetero(g, op, len(lhs_data), lhs_target,
+                                      rhs_target, *lhs_and_rhs_tuple)

 def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
     name = "{}_{}_{}".format(lhs_target, binary_op, rhs_target)
...
@@ -79,14 +79,15 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
         if reduce_op in ['min', 'max']:
             ret = F.replace_inf_with_zero(ret)
     else:
-        if op in ['copy_lhs', 'copy_rhs']:
-            lhs_and_rhs_tuple = lhs_data if rhs_data is None else rhs_data
+        # lhs_data or rhs_data is None only in unary functions like ``copy_u`` or ``copy_e``
+        lhs_data = [None] * g._graph.number_of_ntypes() if lhs_data is None else lhs_data
+        rhs_data = [None] * g._graph.number_of_etypes() if rhs_data is None else rhs_data
+        # TODO (Israt): Call reshape func
+        lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))
         ret = gspmm_internal_hetero(g, op,
                                     'sum' if reduce_op == 'mean' else reduce_op,
-                                    *lhs_and_rhs_tuple)
+                                    len(lhs_data), *lhs_and_rhs_tuple)
     # TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph
     # divide in degrees for mean reducer.
     if reduce_op == 'mean':
         ret_shape = F.shape(ret)
...
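The padding above keeps the packed tuple's layout fixed even for unary ops: a missing operand becomes a list of `None`s of the right length rather than being dropped. A sketch for a `copy_u` call on a graph with two ntypes and three etypes (tensor names are illustrative):

    lhs_data = [h_user, h_game]   # per-ntype features (illustrative)
    rhs_data = None               # copy_u has no edge operand
    rhs_data = [None] * 3 if rhs_data is None else rhs_data
    packed = tuple(list(lhs_data) + list(rhs_data))
    # len(packed) == 5; lhs_len == 2 tells the kernel where lhs ends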
@@ -64,6 +64,17 @@ def to_dgl_nd_for_write(x):
     return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)

+def get_typeid_by_target(g, rel, target):
+    """Find the src/dst/etype id based on the target 'u', 'v' or 'e'."""
+    srctype, _, dsttype = rel
+    etid = g.get_etype_id(rel)
+    if target in [0, 'u']:
+        return g.get_ntype_id(srctype)
+    if target in [2, 'v']:
+        return g.get_ntype_id(dsttype)
+    return etid
+
 target_mapping = {
     'u': 0,
     'e': 1,
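A quick illustration of the helper just added. It accepts either the string targets or the already-mapped integers from `target_mapping`, since `_gsddmm_hetero` calls it after the mapping step (internal helper; the import path is assumed from this diff):

    import dgl
    import torch as th
    from dgl.sparse import get_typeid_by_target  # assumption: path per this diff

    g = dgl.heterograph({('user', 'plays', 'game'):
                         (th.tensor([0]), th.tensor([0]))})
    rel = ('user', 'plays', 'game')
    assert get_typeid_by_target(g, rel, 'u') == g.get_ntype_id('user')
    assert get_typeid_by_target(g, rel, 'v') == g.get_ntype_id('game')
    assert get_typeid_by_target(g, rel, 'e') == g.get_etype_id(rel)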
@@ -179,15 +190,13 @@ def _gspmm(gidx, op, reduce_op, u, e):
     return v, (arg_u, arg_e)

-def _gspmm_hetero(g, op, reduce_op, u_and_e_tuple):
+def _gspmm_hetero(g, op, reduce_op, u_len, u_and_e_tuple):
     r""" Generalized Sparse Matrix Multiplication interface.
     """
-    num_ntypes = g._graph.number_of_ntypes()
-    u_tuple, e_tuple = u_and_e_tuple[:num_ntypes], u_and_e_tuple[num_ntypes:]
+    u_tuple, e_tuple = u_and_e_tuple[:u_len], u_and_e_tuple[u_len:]
     gidx = g._graph
     use_u = op != 'copy_rhs'
     use_e = op != 'copy_lhs'
     # TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
     # deal with scalar features.
@@ -337,33 +346,33 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
     return out

-def _gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None):
+def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None):
     r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
     """
-    num_ntypes = g._graph.number_of_ntypes()
-    lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:num_ntypes], lhs_and_rhs_tuple[num_ntypes:]
     gidx = g._graph
+    lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:lhs_len], lhs_and_rhs_tuple[lhs_len:]
     use_lhs = op != 'copy_rhs'
     use_rhs = op != 'copy_lhs'
     # TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
     # deal with scalar features.
     expand_lhs, expand_rhs = False, False
+    num_ntype = g._graph.number_of_ntypes()
+    num_etype = g._graph.number_of_etypes()
+    lhs_list = [None] * num_ntype if lhs_target in ['u', 'v'] else [None] * num_etype
+    rhs_list = [None] * num_ntype if rhs_target in ['u', 'v'] else [None] * num_etype
+    out_list = [None] * gidx.number_of_etypes()
     lhs_target = target_mapping[lhs_target]
     rhs_target = target_mapping[rhs_target]
-    lhs_list = [None] * gidx.number_of_ntypes()
-    rhs_list = [None] * gidx.number_of_ntypes()
-    out_list = [None] * gidx.number_of_etypes()
     for rel in g.canonical_etypes:
-        srctype, _, dsttype = rel
         etid = g.get_etype_id(rel)
-        src_id = g.get_ntype_id(srctype)
-        dst_id = g.get_ntype_id(dsttype)
-        lhs = lhs_tuple[src_id]
-        rhs = rhs_tuple[dst_id]
+        lhs_id = get_typeid_by_target(g, rel, lhs_target)
+        rhs_id = get_typeid_by_target(g, rel, rhs_target)
+        lhs = lhs_tuple[lhs_id]
+        rhs = rhs_tuple[rhs_id]
         if use_lhs:
             if lhs is not None and F.ndim(lhs) == 1:
                 lhs = F.unsqueeze(lhs, -1)
@@ -372,17 +381,15 @@ def _gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None
         if rhs is not None and F.ndim(rhs) == 1:
             rhs = F.unsqueeze(lhs, -1)
             expand_rhs = True
         ctx = F.context(lhs) if use_lhs else F.context(rhs)
         dtype = F.dtype(lhs) if use_lhs else F.dtype(rhs)
         lhs_shp = F.shape(lhs) if use_lhs else (0,)
         rhs_shp = F.shape(rhs) if use_rhs else (0,)
-        lhs_list[src_id] = lhs if use_lhs else None
-        rhs_list[dst_id] = rhs if use_rhs else None
+        lhs_list[lhs_id] = lhs if use_lhs else None
+        rhs_list[rhs_id] = rhs if use_rhs else None
         out_shp = (gidx.number_of_edges(etid), ) +\
             infer_broadcast_shape(op, lhs_shp[1:], rhs_shp[1:])
         out_list[etid] = F.zeros(out_shp, dtype, ctx)
     if gidx.number_of_edges(0) > 0:
         _CAPI_DGLKernelSDDMMHetero(gidx, op,
                                    [to_dgl_nd(lhs) for lhs in lhs_list],
...
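The `infer_broadcast_shape` call sizes each per-etype output buffer by broadcasting the operands' trailing dimensions, just like plain tensor broadcasting (PyTorch sketch, illustrative):

    import torch

    lhs = torch.randn(5, 4)   # e.g. a gathered per-edge lhs operand
    rhs = torch.randn(5, 1)   # rhs with a broadcastable trailing dim
    out = lhs * rhs           # op='mul' -> output shape (5, 4)
    assert out.shape == (5, 4)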
@@ -519,7 +519,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
                    const std::vector<NDArray>& out_aux,
                    const std::vector<dgl_type_t>& ufeat_ntids,  // ufeat node type id
                    const std::vector<dgl_type_t>& out_ntids) {  // output node type id
-  bool is_scalar_efeat = vec_efeat.size() != 0;
+  bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
   bool use_efeat = op != "copy_lhs";
   // TODO(Israt): Resolve PR-https://github.com/dmlc/dgl/issues/2995 and use multistream
   auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
@@ -589,8 +589,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
           cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
         NDArray efeat = vec_efeat[etype];
         if (!IsNullArray(csr.data))
-          efeat = _IndexSelect<DType, IdType>(vec_efeat[etype], csr.data);
+          efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
         cusparse::CusparseCsrmm2Hetero<DType, IdType>(
             csr.indptr->ctx, csr,
             static_cast<DType*>(vec_ufeat[src_id]->data),
...
@@ -127,6 +127,23 @@ void SDDMM(const std::string& op,
 }

+/*!
+ * \brief Find the src/dst/etype id based on the target 'u', 'v' or 'e'.
+ *
+ * \param graph The input graph.
+ * \param target 'u', 'v' or 'e'. The target of the lhs or rhs data of an etype.
+ * \param etype Relation type of the input graph.
+ */
+int get_typeid_by_target(HeteroGraphPtr graph, int target, dgl_type_t etype) {
+  auto pair = graph->meta_graph()->FindEdge(etype);
+  if (target == 0)
+    return pair.first;
+  if (target == 2)
+    return pair.second;
+  return etype;
+}
+
 /*! \brief Generalized Sampled Dense-Dense Matrix Multiplication. */
 void SDDMMHetero(const std::string& op,
                  HeteroGraphPtr graph,
@@ -143,9 +160,8 @@ void SDDMMHetero(const std::string& op,
   std::vector<dgl_type_t> rhs_eid;
   for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
     vec_csr.push_back(graph->GetCSRMatrix(etype));
-    auto pair = graph->meta_graph()->FindEdge(etype);
-    lhs_eid.push_back(pair.first);
-    rhs_eid.push_back(pair.second);
+    lhs_eid.push_back(get_typeid_by_target(graph, lhs_target, etype));
+    rhs_eid.push_back(get_typeid_by_target(graph, rhs_target, etype));
   }
   const auto &bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]);
...
@@ -4,6 +4,7 @@ from collections import Counter
 import numpy as np
 import scipy.sparse as ssp
 import itertools
+from itertools import product
 import backend as F
 import networkx as nx
 import unittest, pytest
@@ -37,9 +38,6 @@ def create_test_heterograph(idtype):
     assert g.device == F.ctx()
     return g

-# def init_features(idtype):

 @parametrize_dtype
 def test_unary_copy_u(idtype):
     def _test(mfunc, rfunc):
@@ -167,8 +165,89 @@ def test_unary_copy_e(idtype):
     # _test('copy_e', 'mean')

+@parametrize_dtype
+def test_binary_op(idtype):
+    def _test(lhs, rhs, binary_op, reducer):
+        g = create_test_heterograph(idtype)
+
+        x1 = F.randn((g.num_nodes('user'), feat_size))
+        x2 = F.randn((g.num_nodes('developer'), feat_size))
+        x3 = F.randn((g.num_nodes('game'), feat_size))
+        F.attach_grad(x1)
+        F.attach_grad(x2)
+        F.attach_grad(x3)
+        g.nodes['user'].data['h'] = x1
+        g.nodes['developer'].data['h'] = x2
+        g.nodes['game'].data['h'] = x3
+
+        x1 = F.randn((4, feat_size))
+        x2 = F.randn((4, feat_size))
+        x3 = F.randn((3, feat_size))
+        x4 = F.randn((3, feat_size))
+        F.attach_grad(x1)
+        F.attach_grad(x2)
+        F.attach_grad(x3)
+        F.attach_grad(x4)
+        g['plays'].edata['h'] = x1
+        g['follows'].edata['h'] = x2
+        g['develops'].edata['h'] = x3
+        g['wishes'].edata['h'] = x4
+
+        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
+        builtin_msg = getattr(fn, builtin_msg_name)
+        builtin_red = getattr(fn, reducer)
+
+        #################################################################
+        # multi_update_all(): call msg_passing separately for each etype
+        #################################################################
+        with F.record_grad():
+            g.multi_update_all(
+                {etype : (builtin_msg('h', 'h', 'm'), builtin_red('m', 'y'))
+                 for etype in g.canonical_etypes},
+                'sum')
+            r1 = g.nodes['game'].data['y']
+            F.backward(r1, F.ones(r1.shape))
+            n_grad1 = F.grad(r1)
+
+        #################################################################
+        # update_all(): call msg_passing for all etypes
+        #################################################################
+        g.update_all(builtin_msg('h', 'h', 'm'), builtin_red('m', 'y'))
+        r2 = g.nodes['game'].data['y']
+        F.backward(r2, F.ones(r2.shape))
+        n_grad2 = F.grad(r2)
+
+        # correctness check
+        def _print_error(a, b):
+            for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
+                if not np.allclose(x, y):
+                    print('@{} {} v.s. {}'.format(i, x, y))
+
+        if not F.allclose(r1, r2):
+            _print_error(r1, r2)
+        assert F.allclose(r1, r2)
+        # TODO (Israt): r1 and r2 have different grad funcs associated with them
+        # if not F.allclose(n_grad1, n_grad2):
+        #     print('node grad')
+        #     _print_error(n_grad1, n_grad2)
+        # assert(F.allclose(n_grad1, n_grad2))
+
+    target = ["u", "v", "e"]
+    for lhs, rhs in product(target, target):
+        if lhs == rhs:
+            continue
+        for binary_op in ["add", "sub", "mul", "div"]:
+            # TODO(Israt): Add support for reduce func "max", "min", "mean"
+            for reducer in ["sum"]:
+                print(lhs, rhs, binary_op, reducer)
+                _test(lhs, rhs, binary_op, reducer)

 if __name__ == '__main__':
     test_unary_copy_u()
     test_unary_copy_e()
+    test_binary_op()
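The driver loop exercises every ordered pair of distinct operand targets; with three targets that is six combinations per binary op:

    from itertools import product

    target = ["u", "v", "e"]
    pairs = [(l, r) for l, r in product(target, target) if l != r]
    assert len(pairs) == 6   # uv, ue, vu, ve, eu, ev, each run with add/sub/mul/div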