"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9ff72433fa5a4d9f9e2f2c599e394480b581c614"
Unverified commit 532eaa87, authored by Israt Nisa, committed by GitHub
Browse files

backward now stores DGLGraph index, not DGLGraph object with attached data (#3410)


Co-authored-by: Israt Nisa <nisisrat@amazon.com>
parent aef96dfa
...@@ -2,7 +2,7 @@ import torch as th ...@@ -2,7 +2,7 @@ import torch as th
from distutils.version import LooseVersion from distutils.version import LooseVersion
from ...base import is_all, ALL from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _scatter_add from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask, get_typeid_by_target from ...sparse import _csrmm, _csrsum, _csrmask
from ...heterograph_index import create_unitgraph_from_csr from ...heterograph_index import create_unitgraph_from_csr
if LooseVersion(th.__version__) >= LooseVersion("1.6.0"): if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
...@@ -192,12 +192,12 @@ class GSpMM(th.autograd.Function): ...@@ -192,12 +192,12 @@ class GSpMM(th.autograd.Function):
class GSpMM_hetero(th.autograd.Function): class GSpMM_hetero(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16) @custom_fwd(cast_inputs=th.float16)
def forward(ctx, g, op, reduce_op, X_len, *feats): # feats = lhs_data + rhs_data def forward(ctx, gidx, op, reduce_op, X_len, *feats): # feats = lhs_data + rhs_data
out, (argX, argY) = _gspmm_hetero(g, op, reduce_op, X_len, feats) out, (argX, argY) = _gspmm_hetero(gidx, op, reduce_op, X_len, feats)
X, Y = feats[:X_len], feats[X_len:] X, Y = feats[:X_len], feats[X_len:]
# TODO (Israt): check target to decide src_id/dst_id? # TODO (Israt): check target to decide src_id/dst_id?
# checking the first relation to decide for all the relations # checking the first relation to decide for all the relations
src_id, dst_id = g._graph.metagraph.find_edge(0) src_id, dst_id = gidx.metagraph.find_edge(0)
reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id]) reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id])
X_shape = tuple([X[i].shape if X[i] is not None else None X_shape = tuple([X[i].shape if X[i] is not None else None
for i in range(X_len)]) for i in range(X_len)])
...@@ -205,7 +205,7 @@ class GSpMM_hetero(th.autograd.Function): ...@@ -205,7 +205,7 @@ class GSpMM_hetero(th.autograd.Function):
for i in range(len(Y))]) for i in range(len(Y))])
dtype = X[src_id].dtype if X[src_id] is not None else Y[dst_id].dtype dtype = X[src_id].dtype if X[src_id] is not None else Y[dst_id].dtype
device = X[src_id].device if X[src_id] is not None else Y[dst_id].device device = X[src_id].device if X[src_id] is not None else Y[dst_id].device
ctx.backward_cache = g, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len ctx.backward_cache = gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len
req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False
for i in range(X_len)]) for i in range(X_len)])
req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False
...@@ -223,14 +223,14 @@ class GSpMM_hetero(th.autograd.Function): ...@@ -223,14 +223,14 @@ class GSpMM_hetero(th.autograd.Function):
@staticmethod @staticmethod
@custom_bwd @custom_bwd
def backward(ctx, *dZ): def backward(ctx, *dZ):
g, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
feats = ctx.saved_tensors[:-2] feats = ctx.saved_tensors[:-2]
argX = ctx.saved_tensors[-2] argX = ctx.saved_tensors[-2]
argY = ctx.saved_tensors[-1] argY = ctx.saved_tensors[-1]
X, Y = feats[:X_len], feats[X_len:] X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]): if op != 'copy_rhs' and any([x is not None for x in X]):
g_rev = g.reverse() g_rev = gidx.reverse()
# TODO(Israt): implement other combinations of message and reduce functions # TODO(Israt): implement other combinations of message and reduce functions
if reduce_op == 'sum': if reduce_op == 'sum':
if op == 'mul': if op == 'mul':
...@@ -251,11 +251,11 @@ class GSpMM_hetero(th.autograd.Function): ...@@ -251,11 +251,11 @@ class GSpMM_hetero(th.autograd.Function):
for i in range(len(dZ))]) for i in range(len(dZ))])
tpl_X_dZ = tuple(X + tpl_dZ) tpl_X_dZ = tuple(X + tpl_dZ)
if op == 'mul' and reduce_last: if op == 'mul' and reduce_last:
dY = gsddmm_hetero(g, 'dot', X_len, 'u', 'v', *tpl_X_dZ) dY = gsddmm_hetero(gidx, 'dot', X_len, 'u', 'v', *tpl_X_dZ)
elif op == 'mul': elif op == 'mul':
dY = gsddmm_hetero(g, 'mul', X_len, 'u', 'v', *tpl_X_dZ) dY = gsddmm_hetero(gidx, 'mul', X_len, 'u', 'v', *tpl_X_dZ)
elif op in ['add', 'copy_rhs']: elif op in ['add', 'copy_rhs']:
dY = gsddmm_hetero(g, 'copy_rhs', X_len, 'u', 'v', *tpl_X_dZ) dY = gsddmm_hetero(gidx, 'copy_rhs', X_len, 'u', 'v', *tpl_X_dZ)
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
for i in range(len(Y))]) for i in range(len(Y))])
else: # Y has no gradient else: # Y has no gradient
...@@ -345,20 +345,18 @@ class GSDDMM(th.autograd.Function): ...@@ -345,20 +345,18 @@ class GSDDMM(th.autograd.Function):
class GSDDMM_hetero(th.autograd.Function): class GSDDMM_hetero(th.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=th.float16) @custom_fwd(cast_inputs=th.float16)
def forward(ctx, g, op, X_len, lhs_target, rhs_target, *feats): # feats = X+Y def forward(ctx, gidx, op, X_len, lhs_target, rhs_target, *feats): # feats = X+Y
out = _gsddmm_hetero(g, op, X_len, lhs_target, rhs_target, feats) out = _gsddmm_hetero(gidx, op, X_len, lhs_target, rhs_target, feats)
X, Y = feats[:X_len], feats[X_len:] X, Y = feats[:X_len], feats[X_len:]
X_shape = tuple([X[i].shape if X[i] is not None else None X_shape = tuple([X[i].shape if X[i] is not None else None
for i in range(len(X))]) for i in range(len(X))])
Y_shape = tuple([Y[i].shape if Y[i] is not None else None Y_shape = tuple([Y[i].shape if Y[i] is not None else None
for i in range(len(Y))]) for i in range(len(Y))])
ctx.backward_cache = g, op, lhs_target, rhs_target, X_shape, Y_shape, X_len ctx.backward_cache = gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len
req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False
for i in range(len(X))]) for i in range(len(X))])
req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False
for i in range(len(Y))]) for i in range(len(Y))])
lhs_id = get_typeid_by_target(g, g.canonical_etypes[0], lhs_target)
rhs_id = get_typeid_by_target(g, g.canonical_etypes[0], rhs_target)
ctx.save_for_backward(*feats) ctx.save_for_backward(*feats)
return out return out
...@@ -366,56 +364,56 @@ class GSDDMM_hetero(th.autograd.Function): ...@@ -366,56 +364,56 @@ class GSDDMM_hetero(th.autograd.Function):
@custom_bwd @custom_bwd
# TODO(Israt): Implement the complete backward operator # TODO(Israt): Implement the complete backward operator
def backward(ctx, *dZ): def backward(ctx, *dZ):
g, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache
feats = ctx.saved_tensors feats = ctx.saved_tensors
X, Y = feats[:X_len], feats[X_len:] X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]): if op != 'copy_rhs' and any([x is not None for x in X]):
if lhs_target in ['u', 'v']: if lhs_target in ['u', 'v']:
_g = g if lhs_target == 'v' else g.reverse() _gidx = gidx if lhs_target == 'v' else gidx.reverse()
tpl_of_None = tuple([None] * len(X)) tpl_of_None = tuple([None] * len(X))
if op in ['add', 'copy_lhs']: if op in ['add', 'copy_lhs']:
dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) dX = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
else: # mul, dot else: # mul, dot
if rhs_target == lhs_target: if rhs_target == lhs_target:
dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * Y dX = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * Y
elif rhs_target == 'e': elif rhs_target == 'e':
dZ_mul_Y = tuple([dZ[i] * Y[i] if dZ[i] is not None else None dZ_mul_Y = tuple([dZ[i] * Y[i] if dZ[i] is not None else None
for i in range(len(Y))]) for i in range(len(Y))])
dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_Y))) dX = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_Y)))
else: # rhs_target = !lhs_target else: # rhs_target = !lhs_target
dX = gspmm_hetero(_g, 'mul', 'sum', len(X), *tuple(Y + dZ)) dX = gspmm_hetero(_gidx, 'mul', 'sum', len(X), *tuple(Y + dZ))
else: # lhs_target == 'e' else: # lhs_target == 'e'
if op in ['add', 'copy_lhs']: if op in ['add', 'copy_lhs']:
dX = dZ dX = dZ
else: # mul, dot else: # mul, dot
num_etype = g._graph.number_of_etypes() num_etype = gidx.number_of_etypes()
dX = gsddmm_hetero(g, 'mul', num_etype, 'e', rhs_target, *tuple(dZ + Y)) dX = gsddmm_hetero(gidx, 'mul', num_etype, 'e', rhs_target, *tuple(dZ + Y))
dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None
for i in range(len(X))]) for i in range(len(X))])
else: else:
dX = tuple([None] * len(X)) dX = tuple([None] * len(X))
if op != 'copy_lhs' and any([y is not None for y in Y]): if op != 'copy_lhs' and any([y is not None for y in Y]):
if rhs_target in ['u', 'v']: if rhs_target in ['u', 'v']:
_g = g if rhs_target == 'v' else g.reverse() _gidx = gidx if rhs_target == 'v' else gidx.reverse()
tpl_of_None = tuple([None] * len(X)) tpl_of_None = tuple([None] * len(X))
if op in ['add', 'copy_rhs']: if op in ['add', 'copy_rhs']:
dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) dY = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
else: # mul, dot else: # mul, dot
if lhs_target == rhs_target: if lhs_target == rhs_target:
dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * X dY = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * X
elif lhs_target == 'e': elif lhs_target == 'e':
dZ_mul_X = tuple([dZ[i] * X[i] if dZ[i] is not None else None dZ_mul_X = tuple([dZ[i] * X[i] if dZ[i] is not None else None
for i in range(len(X))]) for i in range(len(X))])
dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_X))) dY = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_X)))
else: # rhs_target = !lhs_target else: # rhs_target = !lhs_target
dY = gspmm_hetero(_g, 'mul', 'sum', len(X), *tuple(X + dZ)) dY = gspmm_hetero(_gidx, 'mul', 'sum', len(X), *tuple(X + dZ))
else: else:
if op in ['add', 'copy_rhs']: if op in ['add', 'copy_rhs']:
dY = tuple([dZ[i] if dZ[i] is not None else None dY = tuple([dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))]) for i in range(len(dZ))])
else: # mul, dot else: # mul, dot
num_etype = g._graph.number_of_etypes() num_etype = gidx.number_of_etypes()
dY = gsddmm_hetero(g, 'mul', num_etype, 'e', lhs_target, *tuple(dZ + X)) dY = gsddmm_hetero(gidx, 'mul', num_etype, 'e', lhs_target, *tuple(dZ + X))
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
for i in range(len(Y))]) for i in range(len(Y))])
else: else:
......
...@@ -83,7 +83,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'): ...@@ -83,7 +83,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
# different dimensions, and different etypes may need different broadcasting # different dimensions, and different etypes may need different broadcasting
# dims for the same node. # dims for the same node.
lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data)) lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))
return gsddmm_internal_hetero(g, op, len(lhs_data), lhs_target, return gsddmm_internal_hetero(g._graph, op, len(lhs_data), lhs_target,
rhs_target, *lhs_and_rhs_tuple) rhs_target, *lhs_and_rhs_tuple)
def _gen_sddmm_func(lhs_target, rhs_target, binary_op): def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
......
...@@ -84,7 +84,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data): ...@@ -84,7 +84,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
rhs_data = [None] * g._graph.number_of_etypes() if rhs_data is None else rhs_data rhs_data = [None] * g._graph.number_of_etypes() if rhs_data is None else rhs_data
# TODO (Israt): Call reshape func # TODO (Israt): Call reshape func
lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data)) lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))
ret = gspmm_internal_hetero(g, op, ret = gspmm_internal_hetero(g._graph, op,
'sum' if reduce_op == 'mean' else reduce_op, 'sum' if reduce_op == 'mean' else reduce_op,
len(lhs_data), *lhs_and_rhs_tuple) len(lhs_data), *lhs_and_rhs_tuple)
# TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph # TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph
......
...@@ -64,14 +64,13 @@ def to_dgl_nd_for_write(x): ...@@ -64,14 +64,13 @@ def to_dgl_nd_for_write(x):
return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x) return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)
def get_typeid_by_target(g, rel, target): def get_typeid_by_target(gidx, etid, target):
"""Find the src/dst/etype id based on the target 'u', 'v' or 'e'.""" """Find the src/dst/etype id based on the target 'u', 'v' or 'e'."""
srctype, _, dsttype = rel src_id, dst_id = gidx.metagraph.find_edge(etid)
etid = g.get_etype_id(rel)
if target in [0, 'u']: if target in [0, 'u']:
return g.get_ntype_id(srctype) return src_id
if target in [2, 'v']: if target in [2, 'v']:
return g.get_ntype_id(dsttype) return dst_id
return etid return etid
...@@ -190,11 +189,10 @@ def _gspmm(gidx, op, reduce_op, u, e): ...@@ -190,11 +189,10 @@ def _gspmm(gidx, op, reduce_op, u, e):
return v, (arg_u, arg_e) return v, (arg_u, arg_e)
def _gspmm_hetero(g, op, reduce_op, u_len, u_and_e_tuple): def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
r""" Generalized Sparse Matrix Multiplication interface. r""" Generalized Sparse Matrix Multiplication interface.
""" """
u_tuple, e_tuple = u_and_e_tuple[:u_len], u_and_e_tuple[u_len:] u_tuple, e_tuple = u_and_e_tuple[:u_len], u_and_e_tuple[u_len:]
gidx = g._graph
use_u = op != 'copy_rhs' use_u = op != 'copy_rhs'
use_e = op != 'copy_lhs' use_e = op != 'copy_lhs'
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e): # TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
...@@ -205,11 +203,8 @@ def _gspmm_hetero(g, op, reduce_op, u_len, u_and_e_tuple): ...@@ -205,11 +203,8 @@ def _gspmm_hetero(g, op, reduce_op, u_len, u_and_e_tuple):
list_v = [None] * gidx.number_of_ntypes() list_v = [None] * gidx.number_of_ntypes()
list_e = [None] * gidx.number_of_etypes() list_e = [None] * gidx.number_of_etypes()
for rel in g.canonical_etypes: for etid in range(gidx.number_of_etypes()):
srctype, _, dsttype = rel src_id, dst_id = gidx.metagraph.find_edge(etid)
etid = g.get_etype_id(rel)
src_id = g.get_ntype_id(srctype)
dst_id = g.get_ntype_id(dsttype)
u = u_tuple[src_id] if use_u else None u = u_tuple[src_id] if use_u else None
e = e_tuple[etid] if use_e else None e = e_tuple[etid] if use_e else None
if use_u: if use_u:
...@@ -346,10 +341,9 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'): ...@@ -346,10 +341,9 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
return out return out
def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None): def _gsddmm_hetero(gidx, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface. r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
""" """
gidx = g._graph
lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:lhs_len], lhs_and_rhs_tuple[lhs_len:] lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:lhs_len], lhs_and_rhs_tuple[lhs_len:]
use_lhs = op != 'copy_rhs' use_lhs = op != 'copy_rhs'
...@@ -358,8 +352,8 @@ def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_t ...@@ -358,8 +352,8 @@ def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_t
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e): # TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
# deal with scalar features. # deal with scalar features.
expand_lhs, expand_rhs = False, False expand_lhs, expand_rhs = False, False
num_ntype = g._graph.number_of_ntypes() num_ntype = gidx.number_of_ntypes()
num_etype = g._graph.number_of_etypes() num_etype = gidx.number_of_etypes()
lhs_list = [None] * num_ntype if lhs_target in ['u', 'v'] else [None] * num_etype lhs_list = [None] * num_ntype if lhs_target in ['u', 'v'] else [None] * num_etype
rhs_list = [None] * num_ntype if rhs_target in ['u', 'v'] else [None] * num_etype rhs_list = [None] * num_ntype if rhs_target in ['u', 'v'] else [None] * num_etype
out_list = [None] * gidx.number_of_etypes() out_list = [None] * gidx.number_of_etypes()
...@@ -367,10 +361,9 @@ def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_t ...@@ -367,10 +361,9 @@ def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_t
lhs_target = target_mapping[lhs_target] lhs_target = target_mapping[lhs_target]
rhs_target = target_mapping[rhs_target] rhs_target = target_mapping[rhs_target]
for rel in g.canonical_etypes: for etid in range(gidx.number_of_etypes()):
etid = g.get_etype_id(rel) lhs_id = get_typeid_by_target(gidx, etid, lhs_target)
lhs_id = get_typeid_by_target(g, rel, lhs_target) rhs_id = get_typeid_by_target(gidx, etid, rhs_target)
rhs_id = get_typeid_by_target(g, rel, rhs_target)
lhs = lhs_tuple[lhs_id] lhs = lhs_tuple[lhs_id]
rhs = rhs_tuple[rhs_id] rhs = rhs_tuple[rhs_id]
if use_lhs: if use_lhs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment