Unverified Commit 298e4fa6 authored by Israt Nisa, committed by GitHub

[Feature] Support builtin binary message function for heterogeneous graph (#3273)



* Added binary builtinMsgFunc forward() for heterograph

* Added backward for u_op_v

* Supports all binary builtin forward

* Supports binary message funcs with reduce func sum

* lint check

* removed import torch from unittest

* enabled GPU test

* lint check

* Fixed docstrings

* rename func get_hs_id

* edited comment
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
parent c81efdf2
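For context, a minimal sketch of what this commit enables, mirroring the toy schema used by the unit tests below (graph and feature names are illustrative): binary builtin message functions such as u_mul_v now work in update_all() on a heterogeneous graph, where previously only copy_u/copy_e were accepted.

import dgl
import dgl.function as fn
import torch as th

# Toy heterograph; both relations point into 'game', so the reduced
# output lands on a single destination type.
g = dgl.heterograph({
    ('user', 'plays', 'game'): ([0, 1, 1], [0, 0, 1]),
    ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
})
g.nodes['user'].data['h'] = th.randn(2, 4)
g.nodes['developer'].data['h'] = th.randn(2, 4)
g.nodes['game'].data['h'] = th.randn(2, 4)
# Binary message function + sum reducer, fused across all etypes.
# Before this commit, update_all() raised NotImplementedError here.
g.update_all(fn.u_mul_v('h', 'h', 'm'), fn.sum('m', 'y'))
print(g.nodes['game'].data['y'].shape)  # torch.Size([2, 4])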
@@ -1467,7 +1467,7 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
"""
pass
def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
r""" Generalized Sparse Matrix Multiplication interface on heterogenenous graph.
All the relation types of the heterogeneous graph will be processed together.
It fuses two steps into one kernel.
@@ -1493,6 +1493,8 @@ def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
``copy_lhs``, ``copy_rhs``.
reduce_op : str
Reduce operator, could be ``sum``, ``max``, ``min``.
lhs_len : int
Length of the lhs data, i.e. the number of tensors in the lhs tuple (the split point of ``lhs_and_rhs_tuple``).
lhs_and_rhs_tuple : tuple of tensors
lhs_data and rhs_data are concatenated to one tuple. lhs_data is
also a tuple of tensors of size number of ntypes. Same is true for
@@ -1541,7 +1543,7 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
"""
pass
def gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface on
heterogeneous graph. All the relation types of the heterogeneous graph
will be processed together.
@@ -1562,6 +1564,8 @@ def gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
op : str
Binary operator, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
``copy_lhs``, ``copy_rhs``.
lhs_len : int
Length of the lhs data, i.e. the number of tensors in the lhs tuple (the split point of ``lhs_and_rhs_tuple``).
lhs_target: str
Choice of `u`(source), `e`(edge) or `v`(destination) for left operand.
rhs_target: str
...
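The new lhs_len argument replaces the old implicit split of the flat feature tuple at number_of_ntypes, which was only correct when the lhs operand lived on nodes; with binary operators either operand may target 'u', 'v' or 'e'. A minimal sketch of the convention, with plain tensors and illustrative type names (no kernel call):

import torch as th

x_user, x_game = th.randn(3, 4), th.randn(2, 4)  # one slot per ntype
e_plays = th.randn(5, 4)                         # one slot per etype
lhs_data = (x_user, x_game)
rhs_data = (e_plays,)
feats = lhs_data + rhs_data   # concatenated tuple, as the API expects
lhs_len = len(lhs_data)       # explicit split point, passed alongside
X, Y = feats[:lhs_len], feats[lhs_len:]
assert X == lhs_data and Y == rhs_data  # same objects per slot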
@@ -2,7 +2,7 @@ import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask
from ...sparse import _csrmm, _csrsum, _csrmask, get_typeid_by_target
from ...heterograph_index import create_unitgraph_from_csr
if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
@@ -192,45 +192,75 @@ class GSpMM(th.autograd.Function):
class GSpMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, g, op, reduce_op, *feats): # feats = lhs_data + rhs_data
out, (argX, argY) = _gspmm_hetero(g, op, reduce_op, feats)
ctx.backward_cache = g, op, reduce_op
def forward(ctx, g, op, reduce_op, X_len, *feats): # feats = lhs_data + rhs_data
out, (argX, argY) = _gspmm_hetero(g, op, reduce_op, X_len, feats)
X, Y = feats[:X_len], feats[X_len:]
# TODO (Israt): check target to decide src_id/dst_id?
# checking the first relation to decide for all the relations
src_id, dst_id = g._graph.metagraph.find_edge(0)
reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id])
X_shape = tuple([X[i].shape if X[i] is not None else None
for i in range(X_len)])
Y_shape = tuple([Y[i].shape if Y[i] is not None else None
for i in range(len(Y))])
dtype = X[src_id].dtype if X[src_id] is not None else Y[dst_id].dtype
device = X[src_id].device if X[src_id] is not None else Y[dst_id].device
ctx.backward_cache = g, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len
req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False
for i in range(X_len)])
req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False
for i in range(len(Y))])
# checking the first relation to decide for all the relations
if not spmm_cache_argX(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
argX = None
if not spmm_cache_argY(op, reduce_op, req_grad_X[src_id], req_grad_Y[dst_id]):
argY = None
ctx.save_for_backward(*feats, argX, argY)
return out
@staticmethod
@custom_bwd
def backward(ctx, *dZ):
g, op, reduce_op = ctx.backward_cache
g, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
feats = ctx.saved_tensors[:-2]
argX = ctx.saved_tensors[-2]
argY = ctx.saved_tensors[-1]
num_ntypes = g._graph.number_of_ntypes()
X, Y = feats[:num_ntypes], feats[num_ntypes:]
X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]):
g_rev = g.reverse()
# TODO(Israt): implement other combinations of message and reduce functions
if reduce_op == 'sum':
if op == 'copy_lhs':
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', *dZ)
dX = tuple([_reduce_grad(dX[i], X[i].shape) if X[i] is not None else None
if op == 'mul':
dX = gspmm_hetero(g_rev, 'mul', 'sum', len(X), *tuple(dZ + Y))
elif op == 'add':
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', len(X), *tuple(dZ + Y))
elif op == 'copy_lhs':
tpl_None = tuple([None] * len(Y))
dX = gspmm_hetero(g_rev, 'copy_lhs', 'sum', len(X), *tuple(dZ + tpl_None))
dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None
for i in range(len(X))])
else: # X has no gradient
dX = tuple([None] * len(X))
if op != 'copy_lhs' and any([y is not None for y in Y]):
# TODO(Israt): implement other combinations of message and reduce functions
# TODO(Israt): implement other combinations of reduce functions
if reduce_op == 'sum':
if op in ['copy_rhs']:
tmp_Z = tuple([dZ[i] if dZ[i] is not None else None
tpl_dZ = tuple([dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))])
tmp = tuple(X + tmp_Z)
dY = gsddmm_hetero(g, 'copy_rhs', 'u', 'v', *tmp)
dY = tuple([_reduce_grad(dY[i], Y[i].shape) if Y[i] is not None else None
tpl_X_dZ = tuple(X + tpl_dZ)
if op == 'mul' and reduce_last:
dY = gsddmm_hetero(g, 'dot', X_len, 'u', 'v', *tpl_X_dZ)
elif op == 'mul':
dY = gsddmm_hetero(g, 'mul', X_len, 'u', 'v', *tpl_X_dZ)
elif op in ['add', 'copy_rhs']:
dY = gsddmm_hetero(g, 'copy_rhs', X_len, 'u', 'v', *tpl_X_dZ)
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
for i in range(len(Y))])
else: # Y has no gradient
dY = tuple([None] * len(Y))
return (None, None, None) + dX + dY
return (None, None, None, None) + dX + dY
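For reference, a sketch of the sum-reduce gradients these branches implement. The forward output per destination node is z_v = \sum_{(u,e) \in \mathcal{N}(v)} x_u \circ y_e, with \circ the binary op; then:

\frac{\partial L}{\partial x_u} = \sum_{(u,e,v) \in E} \frac{\partial L}{\partial z_v} \circ y_e,
\qquad
\frac{\partial L}{\partial y_e} = x_u \circ \frac{\partial L}{\partial z_v}.

The x-gradient aggregates over the out-edges of u, which is itself an SpMM on the reverse graph (hence g.reverse() above); the y-gradient is edgewise, which is an SDDMM (hence the gsddmm_hetero calls in the Y branch).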
def sddmm_cache_X(op, req_grad_X, req_grad_Y):
@@ -315,18 +345,82 @@ class GSDDMM(th.autograd.Function):
class GSDDMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, g, op, lhs_target, rhs_target, *feats): # feats = X+Y
out = _gsddmm_hetero(g, op, lhs_target, rhs_target, feats)
ctx.backward_cache = g, op, lhs_target, rhs_target
def forward(ctx, g, op, X_len, lhs_target, rhs_target, *feats): # feats = X+Y
out = _gsddmm_hetero(g, op, X_len, lhs_target, rhs_target, feats)
X, Y = feats[:X_len], feats[X_len:]
X_shape = tuple([X[i].shape if X[i] is not None else None
for i in range(len(X))])
Y_shape = tuple([Y[i].shape if Y[i] is not None else None
for i in range(len(Y))])
ctx.backward_cache = g, op, lhs_target, rhs_target, X_shape, Y_shape, X_len
req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False
for i in range(len(X))])
req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False
for i in range(len(Y))])
lhs_id = get_typeid_by_target(g, g.canonical_etypes[0], lhs_target)
rhs_id = get_typeid_by_target(g, g.canonical_etypes[0], rhs_target)
ctx.save_for_backward(*feats)
return out
@staticmethod
@custom_bwd
# TODO(Israt): Implement the backward operator
# TODO(Israt): Implement the complete backward operator
def backward(ctx, *dZ):
raise NotImplementedError('Homogenized GSDDMM backward operation is not implemented.')
g, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache
feats = ctx.saved_tensors
X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]):
if lhs_target in ['u', 'v']:
_g = g if lhs_target == 'v' else g.reverse()
tpl_of_None = tuple([None] * len(X))
if op in ['add', 'copy_lhs']:
dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
else: # mul, dot
if rhs_target == lhs_target:
dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * Y
elif rhs_target == 'e':
dZ_mul_Y = tuple([dZ[i] * Y[i] if dZ[i] is not None else None
for i in range(len(Y))])
dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_Y)))
else: # rhs_target != lhs_target
dX = gspmm_hetero(_g, 'mul', 'sum', len(X), *tuple(Y + dZ))
else: # lhs_target == 'e'
if op in ['add', 'copy_lhs']:
dX = dZ
else: # mul, dot
num_etype = g._graph.number_of_etypes()
dX = gsddmm_hetero(g, 'mul', num_etype, 'e', rhs_target, *tuple(dZ + Y))
dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None
for i in range(len(X))])
else:
dX = tuple([None] * len(X))
if op != 'copy_lhs' and any([y is not None for y in Y]):
if rhs_target in ['u', 'v']:
_g = g if rhs_target == 'v' else g.reverse()
tpl_of_None = tuple([None] * len(X))
if op in ['add', 'copy_rhs']:
dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
else: # mul, dot
if lhs_target == rhs_target:
dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * X
elif lhs_target == 'e':
dZ_mul_X = tuple([dZ[i] * X[i] if dZ[i] is not None else None
for i in range(len(X))])
dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_X)))
else: # rhs_target != lhs_target
dY = gspmm_hetero(_g, 'mul', 'sum', len(X), *tuple(X + dZ))
else:
if op in ['add', 'copy_rhs']:
dY = tuple([dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))])
else: # mul, dot
num_etype = g._graph.number_of_etypes()
dY = gsddmm_hetero(g, 'mul', num_etype, 'e', lhs_target, *tuple(dZ + X))
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
for i in range(len(Y))])
else:
dY = tuple([None] * len(Y))
return (None, None, None, None, None) + dX + dY
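The SDDMM backward above mirrors this. The forward output per edge is z_e = x_{\mathrm{lhs}(e)} \circ y_{\mathrm{rhs}(e)}, so the gradient of a node-targeted operand gathers over incident edges:

\frac{\partial L}{\partial x_u} = \sum_{e:\,\mathrm{lhs}(e)=u} \frac{\partial L}{\partial z_e} \circ y_{\mathrm{rhs}(e)},

which is an SpMM on g or g.reverse() depending on the target (the gspmm_hetero calls), while an edge-targeted operand keeps an edgewise gradient (the dX = dZ and gsddmm_hetero 'mul' branches).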
class EdgeSoftmax(th.autograd.Function):
@staticmethod
@@ -502,11 +596,33 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
rhs_data = 1. / rhs_data
return GSDDMM.apply(gidx, op, lhs_data, rhs_data, lhs_target, rhs_target)
def gspmm_hetero(g, op, reduce_op, *lhs_and_rhs_tuple):
return GSpMM_hetero.apply(g, op, reduce_op, *lhs_and_rhs_tuple)
def gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
return GSDDMM_hetero.apply(g, op, lhs_target, rhs_target, *lhs_and_rhs_tuple)
def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:lhs_len], lhs_and_rhs_tuple[lhs_len:]
if op == 'sub':
op = 'add'
rhs_tuple = tuple([-rhs_tuple[i] if rhs_tuple[i] is not None else None
for i in range(len(rhs_tuple))])
if op == 'div':
op = 'mul'
rhs_tuple = tuple([(1. / rhs_tuple[i]) if rhs_tuple[i] is not None else None
for i in range(len(rhs_tuple))])
if op in ['add', 'mul']:
lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
return GSpMM_hetero.apply(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple)
def gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', *lhs_and_rhs_tuple):
lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:lhs_len], lhs_and_rhs_tuple[lhs_len:]
if op == 'sub':
op = 'add'
rhs_tuple = tuple([-rhs_tuple[i] if rhs_tuple[i] is not None else None
for i in range(len(rhs_tuple))])
if op == 'div':
op = 'mul'
rhs_tuple = tuple([(1. / rhs_tuple[i]) if rhs_tuple[i] is not None else None
for i in range(len(rhs_tuple))])
if op in ['add', 'mul']:
lhs_and_rhs_tuple = tuple(list(lhs_tuple) + list(rhs_tuple))
return GSDDMM_hetero.apply(g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple)
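Both wrappers above lower 'sub' and 'div' onto the 'add' and 'mul' kernels by transforming the rhs operand first. A quick sanity sketch of the identities relied on:

import torch as th

u = th.randn(4, 3)
v = th.rand(4, 3) + 0.5                  # keep rhs away from zero for div
assert th.allclose(u - v, u + (-v))      # 'sub' lowered to 'add'
assert th.allclose(u / v, u * (1. / v))  # 'div' lowered to 'mul'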
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
...
@@ -184,7 +184,7 @@ def _bucketing(val):
return bkts
return unique_val, bucketor
def data_dict_to_tuple(graph, data_dict, op, lhs_list=None, rhs_list=None):
def data_dict_to_list(graph, data_dict, func, target):
"""Get node or edge feature data of the given name for all the types.
Parameters
@@ -194,29 +194,46 @@ def data_dict_to_tuple(graph, data_dict, op, lhs_list=None, rhs_list=None):
data_dict : dict[str, Tensor] or dict[(str, str, str), Tensor]]
Node or edge data stored in DGLGraph. The key of the dictionary
is the node type name or edge type name.
op : str
The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``dot``,
``copy_lhs``, ``copy_rhs``.
lhs_list : list[tensor] or list[None]
The feature on source nodes, could be list of None if op is ``copy_rhs``.
rhs_list : list[tensor] or list[None]
The feature on edges, could be list of None if op is ``copy_lhs``.
func : dgl.function.BaseMessageFunction
Built-in message function.
target : 'u', 'v' or 'e'
The target the data is associated with: source node ('u'), destination node ('v') or edge ('e').
Returns
--------
data_tuple : tuple(Tensor)
Feature data stored in tuple of tensors. The i^th tensor stores the feature
data_list : list(Tensor)
Feature data stored in a list of tensors. The i^th tensor stores the feature
data of type ``types[i]``.
"""
if op == "copy_u":
if isinstance(func, fn.BinaryMessageFunction):
if target in ['u', 'v']:
output_list = [None] * graph._graph.number_of_ntypes()
for srctype, _, dsttype in graph.canonical_etypes:
if target == 'u':
src_id = graph.get_ntype_id(srctype)
output_list[src_id] = data_dict[srctype]
else:
dst_id = graph.get_ntype_id(dsttype)
output_list[dst_id] = data_dict[dsttype]
else: # target == 'e'
output_list = [None] * graph._graph.number_of_etypes()
for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel)
output_list[etid] = data_dict[rel]
return output_list
else:
if target == 'u':
lhs_list = [None] * graph._graph.number_of_ntypes()
for srctype, _, _ in graph.canonical_etypes:
src_id = graph.get_ntype_id(srctype)
lhs_list[src_id] = data_dict[srctype]
elif op == "copy_e":
return lhs_list
else: # target == 'e':
rhs_list = [None] * graph._graph.number_of_etypes()
for rel in graph.canonical_etypes:
etid = graph.get_etype_id(rel)
rhs_list[etid] = data_dict[rel]
return tuple(lhs_list + rhs_list)
return rhs_list
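A small sketch of this helper's behaviour on a toy graph (data_dict_to_list is an internal helper in dgl.core, and the graph below is illustrative):

import dgl
import dgl.function as fn
import torch as th
from dgl.core import data_dict_to_list

g = dgl.heterograph({('user', 'plays', 'game'): ([0, 1], [0, 0])})
x = {'user': th.randn(2, 4)}
# copy_u is unary, so the else-branch runs with target 'u': one slot
# per ntype, the 'user' tensor at its ntype id, None everywhere else.
lst = data_dict_to_list(g, x, fn.copy_u('h', 'm'), 'u')
assert len(lst) == 2 and sum(t is not None for t in lst) == 1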
def invoke_gsddmm(graph, func):
"""Invoke g-SDDMM computation on the graph.
@@ -238,10 +255,20 @@ def invoke_gsddmm(graph, func):
x = alldata[func.lhs][func.lhs_field]
y = alldata[func.rhs][func.rhs_field]
op = getattr(ops, func.name)
if graph._graph.number_of_etypes() > 1:
lhs_target, _, rhs_target = func.name.split("_", 2)
x = data_dict_to_list(graph, x, func, lhs_target)
y = data_dict_to_list(graph, y, func, rhs_target)
z = op(graph, x, y)
else:
x = alldata[func.target][func.in_field]
op = getattr(ops, func.name)
if graph._graph.number_of_etypes() > 1:
# Convert to list as dict is unordered.
if func.name == "copy_u":
x = data_dict_to_list(graph, x, func, 'u')
else: # "copy_e"
x = data_dict_to_list(graph, x, func, 'e')
z = op(graph, x)
return {func.out_field : z}
@@ -285,15 +312,19 @@ def invoke_gspmm(graph, mfunc, rfunc, *, srcdata=None, dstdata=None, edata=None)
x = alldata[mfunc.lhs][mfunc.lhs_field]
y = alldata[mfunc.rhs][mfunc.rhs_field]
op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
if graph._graph.number_of_etypes() > 1:
lhs_target, _, rhs_target = mfunc.name.split("_", 2)
x = data_dict_to_list(graph, x, mfunc, lhs_target)
y = data_dict_to_list(graph, y, mfunc, rhs_target)
z = op(graph, x, y)
else:
x = alldata[mfunc.target][mfunc.in_field]
op = getattr(ops, '{}_{}'.format(mfunc.name, rfunc.name))
if graph._graph.number_of_etypes() > 1:
# Convert to list as dict is unordered.
lhs_list = [None] * graph._graph.number_of_ntypes()
rhs_list = [None] * graph._graph.number_of_etypes()
x = data_dict_to_tuple(graph, x, mfunc.name, lhs_list, rhs_list)
if graph._graph.number_of_etypes() > 1 and not isinstance(x, tuple):
if mfunc.name == "copy_u":
x = data_dict_to_list(graph, x, mfunc, 'u')
else: # "copy_e"
x = data_dict_to_list(graph, x, mfunc, 'e')
z = op(graph, x)
return {rfunc.out_field : z}
...
@@ -4861,10 +4861,6 @@ class DGLHeteroGraph(object):
raise NotImplementedError("Cannot set both intra-type and inter-type reduce "
"operators as 'mean' using update_all. Please use "
"multi_update_all instead.")
if message_func.name not in ['copy_u', 'copy_e']:
raise NotImplementedError("Op \'" + message_func.name + "\' is not yet supported"
"in update_all for heterogeneous graphs. Please use"
"multi_update_all instead.")
g = self
all_out = core.message_passing(g, message_func, reduce_func, apply_node_func)
key = list(all_out.keys())[0]
...
@@ -74,22 +74,13 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
return gsddmm_internal(
g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target)
else:
lhs_data_dict = lhs_data
rhs_data_dict = rhs_data
lhs_list = [None] * g._graph.number_of_ntypes()
rhs_list = [None] * g._graph.number_of_ntypes()
for srctype, _, dsttype in g.canonical_etypes:
src_id = g.get_ntype_id(srctype)
dst_id = g.get_ntype_id(dsttype)
lhs_data = lhs_data_dict[srctype]
rhs_data = rhs_data_dict[dsttype]
if op not in ['copy_lhs', 'copy_rhs']:
lhs_data, rhs_data = reshape_lhs_rhs(lhs_data, rhs_data)
lhs_list[src_id] = lhs_data
rhs_list[dst_id] = rhs_data
lhs_and_rhs_tuple = tuple(lhs_list + rhs_list)
# With max and min reducers infinity will be returned for zero degree nodes
return gsddmm_internal_hetero(g, op, lhs_target, rhs_target, *lhs_and_rhs_tuple)
# TODO (Israt): Call reshape_lhs_rhs() on lhs and rhs data to match their dimension
# and avoid broadcasting issue. Handle the case where different nodes have
# different dimensions, and different etypes may need different broadcasting
# dims for the same node.
lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))
return gsddmm_internal_hetero(g, op, len(lhs_data), lhs_target,
rhs_target, *lhs_and_rhs_tuple)
def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
name = "{}_{}_{}".format(lhs_target, binary_op, rhs_target)
...
@@ -79,14 +79,15 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
if reduce_op in ['min', 'max']:
ret = F.replace_inf_with_zero(ret)
else:
if op in ['copy_lhs', 'copy_rhs']:
lhs_and_rhs_tuple = lhs_data if rhs_data is None else rhs_data
# lhs_data or rhs_data is None only in unary functions like ``copy_u`` or ``copy_e``
lhs_data = [None] * g._graph.number_of_ntypes() if lhs_data is None else lhs_data
rhs_data = [None] * g._graph.number_of_etypes() if rhs_data is None else rhs_data
# TODO (Israt): Call reshape func
lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))
ret = gspmm_internal_hetero(g, op,
'sum' if reduce_op == 'mean' else reduce_op,
*lhs_and_rhs_tuple)
len(lhs_data), *lhs_and_rhs_tuple)
# TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph
# divide in degrees for mean reducer.
if reduce_op == 'mean':
ret_shape = F.shape(ret)
...
@@ -64,6 +64,17 @@ def to_dgl_nd_for_write(x):
return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)
def get_typeid_by_target(g, rel, target):
"""Find the src/dst/etype id based on the target 'u', 'v' or 'e'."""
srctype, _, dsttype = rel
etid = g.get_etype_id(rel)
if target in [0, 'u']:
return g.get_ntype_id(srctype)
if target in [2, 'v']:
return g.get_ntype_id(dsttype)
return etid
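A small sketch of this helper on a toy relation (it lives in DGL's internal sparse module, as the backend import above shows; the graph is illustrative):

import dgl
from dgl.sparse import get_typeid_by_target

g = dgl.heterograph({('user', 'plays', 'game'): ([0], [0])})
rel = ('user', 'plays', 'game')
assert get_typeid_by_target(g, rel, 'u') == g.get_ntype_id('user')
assert get_typeid_by_target(g, rel, 'v') == g.get_ntype_id('game')
assert get_typeid_by_target(g, rel, 'e') == g.get_etype_id(rel)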
target_mapping = {
'u': 0,
'e': 1,
@@ -179,15 +190,13 @@ def _gspmm(gidx, op, reduce_op, u, e):
return v, (arg_u, arg_e)
def _gspmm_hetero(g, op, reduce_op, u_and_e_tuple):
def _gspmm_hetero(g, op, reduce_op, u_len, u_and_e_tuple):
r""" Generalized Sparse Matrix Multiplication interface.
"""
num_ntypes = g._graph.number_of_ntypes()
u_tuple, e_tuple = u_and_e_tuple[:num_ntypes], u_and_e_tuple[num_ntypes:]
u_tuple, e_tuple = u_and_e_tuple[:u_len], u_and_e_tuple[u_len:]
gidx = g._graph
use_u = op != 'copy_rhs'
use_e = op != 'copy_lhs'
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
# deal with scalar features.
@@ -337,33 +346,33 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
return out
def _gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None):
def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
"""
num_ntypes = g._graph.number_of_ntypes()
lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:num_ntypes], lhs_and_rhs_tuple[num_ntypes:]
gidx = g._graph
lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:lhs_len], lhs_and_rhs_tuple[lhs_len:]
use_lhs = op != 'copy_rhs'
use_rhs = op != 'copy_lhs'
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
# deal with scalar features.
expand_lhs, expand_rhs = False, False
num_ntype = g._graph.number_of_ntypes()
num_etype = g._graph.number_of_etypes()
lhs_list = [None] * num_ntype if lhs_target in ['u', 'v'] else [None] * num_etype
rhs_list = [None] * num_ntype if rhs_target in ['u', 'v'] else [None] * num_etype
out_list = [None] * gidx.number_of_etypes()
lhs_target = target_mapping[lhs_target]
rhs_target = target_mapping[rhs_target]
lhs_list = [None] * gidx.number_of_ntypes()
rhs_list = [None] * gidx.number_of_ntypes()
out_list = [None] * gidx.number_of_etypes()
for rel in g.canonical_etypes:
srctype, _, dsttype = rel
etid = g.get_etype_id(rel)
src_id = g.get_ntype_id(srctype)
dst_id = g.get_ntype_id(dsttype)
lhs = lhs_tuple[src_id]
rhs = rhs_tuple[dst_id]
lhs_id = get_typeid_by_target(g, rel, lhs_target)
rhs_id = get_typeid_by_target(g, rel, rhs_target)
lhs = lhs_tuple[lhs_id]
rhs = rhs_tuple[rhs_id]
if use_lhs:
if lhs is not None and F.ndim(lhs) == 1:
lhs = F.unsqueeze(lhs, -1)
@@ -372,17 +381,15 @@ def _gsddmm_hetero(g, op, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None
if rhs is not None and F.ndim(rhs) == 1:
rhs = F.unsqueeze(rhs, -1)
expand_rhs = True
ctx = F.context(lhs) if use_lhs else F.context(rhs)
dtype = F.dtype(lhs) if use_lhs else F.dtype(rhs)
lhs_shp = F.shape(lhs) if use_lhs else (0,)
rhs_shp = F.shape(rhs) if use_rhs else (0,)
lhs_list[src_id] = lhs if use_lhs else None
rhs_list[dst_id] = rhs if use_rhs else None
lhs_list[lhs_id] = lhs if use_lhs else None
rhs_list[rhs_id] = rhs if use_rhs else None
out_shp = (gidx.number_of_edges(etid), ) +\
infer_broadcast_shape(op, lhs_shp[1:], rhs_shp[1:])
out_list[etid] = F.zeros(out_shp, dtype, ctx)
if gidx.number_of_edges(0) > 0:
_CAPI_DGLKernelSDDMMHetero(gidx, op,
[to_dgl_nd(lhs) for lhs in lhs_list],
...
@@ -519,7 +519,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, // ufeat node type id
const std::vector<dgl_type_t>& out_ntids) { // output node type id
bool is_scalar_efeat = vec_efeat.size() != 0;
bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
bool use_efeat = op != "copy_lhs";
// TODO(Israt): Resolve PR-https://github.com/dmlc/dgl/issues/2995 and use multistream
auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
@@ -589,8 +589,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
cusparse_available<bits, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(vec_efeat[etype], csr.data);
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
cusparse::CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
...
@@ -127,6 +127,23 @@ void SDDMM(const std::string& op,
}
/*!
* \brief Find the src/dst/etype id based on the target 'u', 'v' or 'e'.
*
* \param graph The input graph.
* \param target Target code: 0 ('u'), 1 ('e') or 2 ('v'). The target of the lhs or rhs data of an etype.
* \param etype Relation type of the input graph.
*/
int get_typeid_by_target(HeteroGraphPtr graph, int target, dgl_type_t etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
if (target == 0)
return pair.first;
if (target == 2)
return pair.second;
return etype;
}
/*! \brief Generalized Sampled Dense-Dense Matrix Multiplication. */
void SDDMMHetero(const std::string& op,
HeteroGraphPtr graph,
@@ -143,9 +160,8 @@ void SDDMMHetero(const std::string& op,
std::vector<dgl_type_t> rhs_eid;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_csr.push_back(graph->GetCSRMatrix(etype));
auto pair = graph->meta_graph()->FindEdge(etype);
lhs_eid.push_back(pair.first);
rhs_eid.push_back(pair.second);
lhs_eid.push_back(get_typeid_by_target(graph, lhs_target, etype));
rhs_eid.push_back(get_typeid_by_target(graph, rhs_target, etype));
}
const auto &bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]);
...
@@ -4,6 +4,7 @@ from collections import Counter
import numpy as np
import scipy.sparse as ssp
import itertools
from itertools import product
import backend as F
import networkx as nx
import unittest, pytest
@@ -37,9 +38,6 @@ def create_test_heterograph(idtype):
assert g.device == F.ctx()
return g
# def init_features(idtype):
@parametrize_dtype
def test_unary_copy_u(idtype):
def _test(mfunc, rfunc):
@@ -167,8 +165,89 @@ def test_unary_copy_e(idtype):
# _test('copy_e', 'mean')
@parametrize_dtype
def test_binary_op(idtype):
def _test(lhs, rhs, binary_op, reducer):
g = create_test_heterograph(idtype)
x1 = F.randn((g.num_nodes('user'), feat_size))
x2 = F.randn((g.num_nodes('developer'), feat_size))
x3 = F.randn((g.num_nodes('game'), feat_size))
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
g.nodes['user'].data['h'] = x1
g.nodes['developer'].data['h'] = x2
g.nodes['game'].data['h'] = x3
x1 = F.randn((4, feat_size))
x2 = F.randn((4, feat_size))
x3 = F.randn((3, feat_size))
x4 = F.randn((3, feat_size))
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g['plays'].edata['h'] = x1
g['follows'].edata['h'] = x2
g['develops'].edata['h'] = x3
g['wishes'].edata['h'] = x4
builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
builtin_msg = getattr(fn, builtin_msg_name)
builtin_red = getattr(fn, reducer)
#################################################################
# multi_update_all(): call msg_passing separately for each etype
#################################################################
with F.record_grad():
g.multi_update_all(
{etype : (builtin_msg('h', 'h', 'm'), builtin_red('m', 'y'))
for etype in g.canonical_etypes},
'sum')
r1 = g.nodes['game'].data['y']
F.backward(r1, F.ones(r1.shape))
n_grad1 = F.grad(r1)
#################################################################
# update_all(): call msg_passing for all etypes
#################################################################
g.update_all(builtin_msg('h', 'h', 'm'), builtin_red('m', 'y'))
r2 = g.nodes['game'].data['y']
F.backward(r2, F.ones(r2.shape))
n_grad2 = F.grad(r2)
# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y):
print('@{} {} v.s. {}'.format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
# TODO (Israt): r1 and r2 have different grad funcs associated with them
# if not F.allclose(n_grad1, n_grad2):
# print('node grad')
# _print_error(n_grad1, n_grad2)
# assert(F.allclose(n_grad1, n_grad2))
target = ["u", "v", "e"]
for lhs, rhs in product(target, target):
if lhs == rhs:
continue
for binary_op in ["add", "sub", "mul", "div"]:
# TODO(Israt): Add support for reduce func "max", "min", "mean"
for reducer in ["sum"]:
print(lhs, rhs, binary_op, reducer)
_test(lhs, rhs, binary_op, reducer)
if __name__ == '__main__':
test_unary_copy_u()
test_unary_copy_e()
test_binary_op()