Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
532eaa87
Unverified
Commit
532eaa87
authored
Oct 11, 2021
by
Israt Nisa
Committed by
GitHub
Oct 11, 2021
Browse files
backward now stores DGLGraph index,not DGLGraph object witattached data (#3410)
Co-authored-by:
Israt Nisa
<
nisisrat@amazon.com
>
parent
aef96dfa
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
52 deletions
+43
-52
python/dgl/backend/pytorch/sparse.py
python/dgl/backend/pytorch/sparse.py
+28
-30
python/dgl/ops/sddmm.py
python/dgl/ops/sddmm.py
+1
-1
python/dgl/ops/spmm.py
python/dgl/ops/spmm.py
+1
-1
python/dgl/sparse.py
python/dgl/sparse.py
+13
-20
No files found.
python/dgl/backend/pytorch/sparse.py
View file @
532eaa87
...
@@ -2,7 +2,7 @@ import torch as th
...
@@ -2,7 +2,7 @@ import torch as th
from
distutils.version
import
LooseVersion
from
distutils.version
import
LooseVersion
from
...base
import
is_all
,
ALL
from
...base
import
is_all
,
ALL
from
...sparse
import
_gspmm
,
_gspmm_hetero
,
_gsddmm
,
_gsddmm_hetero
,
_segment_reduce
,
_bwd_segment_cmp
,
_scatter_add
from
...sparse
import
_gspmm
,
_gspmm_hetero
,
_gsddmm
,
_gsddmm_hetero
,
_segment_reduce
,
_bwd_segment_cmp
,
_scatter_add
from
...sparse
import
_csrmm
,
_csrsum
,
_csrmask
,
get_typeid_by_target
from
...sparse
import
_csrmm
,
_csrsum
,
_csrmask
from
...heterograph_index
import
create_unitgraph_from_csr
from
...heterograph_index
import
create_unitgraph_from_csr
if
LooseVersion
(
th
.
__version__
)
>=
LooseVersion
(
"1.6.0"
):
if
LooseVersion
(
th
.
__version__
)
>=
LooseVersion
(
"1.6.0"
):
...
@@ -192,12 +192,12 @@ class GSpMM(th.autograd.Function):
...
@@ -192,12 +192,12 @@ class GSpMM(th.autograd.Function):
class
GSpMM_hetero
(
th
.
autograd
.
Function
):
class
GSpMM_hetero
(
th
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
@
custom_fwd
(
cast_inputs
=
th
.
float16
)
@
custom_fwd
(
cast_inputs
=
th
.
float16
)
def
forward
(
ctx
,
g
,
op
,
reduce_op
,
X_len
,
*
feats
):
# feats = lhs_data + rhs_data
def
forward
(
ctx
,
g
idx
,
op
,
reduce_op
,
X_len
,
*
feats
):
# feats = lhs_data + rhs_data
out
,
(
argX
,
argY
)
=
_gspmm_hetero
(
g
,
op
,
reduce_op
,
X_len
,
feats
)
out
,
(
argX
,
argY
)
=
_gspmm_hetero
(
g
idx
,
op
,
reduce_op
,
X_len
,
feats
)
X
,
Y
=
feats
[:
X_len
],
feats
[
X_len
:]
X
,
Y
=
feats
[:
X_len
],
feats
[
X_len
:]
# TODO (Israt): check target to decide src_id/dst_id?
# TODO (Israt): check target to decide src_id/dst_id?
# checking the first relation to decide for all the relations
# checking the first relation to decide for all the relations
src_id
,
dst_id
=
g
.
_graph
.
metagraph
.
find_edge
(
0
)
src_id
,
dst_id
=
g
idx
.
metagraph
.
find_edge
(
0
)
reduce_last
=
_need_reduce_last_dim
(
X
[
src_id
],
Y
[
dst_id
])
reduce_last
=
_need_reduce_last_dim
(
X
[
src_id
],
Y
[
dst_id
])
X_shape
=
tuple
([
X
[
i
].
shape
if
X
[
i
]
is
not
None
else
None
X_shape
=
tuple
([
X
[
i
].
shape
if
X
[
i
]
is
not
None
else
None
for
i
in
range
(
X_len
)])
for
i
in
range
(
X_len
)])
...
@@ -205,7 +205,7 @@ class GSpMM_hetero(th.autograd.Function):
...
@@ -205,7 +205,7 @@ class GSpMM_hetero(th.autograd.Function):
for
i
in
range
(
len
(
Y
))])
for
i
in
range
(
len
(
Y
))])
dtype
=
X
[
src_id
].
dtype
if
X
[
src_id
]
is
not
None
else
Y
[
dst_id
].
dtype
dtype
=
X
[
src_id
].
dtype
if
X
[
src_id
]
is
not
None
else
Y
[
dst_id
].
dtype
device
=
X
[
src_id
].
device
if
X
[
src_id
]
is
not
None
else
Y
[
dst_id
].
device
device
=
X
[
src_id
].
device
if
X
[
src_id
]
is
not
None
else
Y
[
dst_id
].
device
ctx
.
backward_cache
=
g
,
op
,
reduce_op
,
X_shape
,
Y_shape
,
dtype
,
device
,
reduce_last
,
X_len
ctx
.
backward_cache
=
g
idx
,
op
,
reduce_op
,
X_shape
,
Y_shape
,
dtype
,
device
,
reduce_last
,
X_len
req_grad_X
=
tuple
([
X
[
i
].
requires_grad
if
X
[
i
]
is
not
None
else
False
req_grad_X
=
tuple
([
X
[
i
].
requires_grad
if
X
[
i
]
is
not
None
else
False
for
i
in
range
(
X_len
)])
for
i
in
range
(
X_len
)])
req_grad_Y
=
tuple
([
Y
[
i
].
requires_grad
if
Y
[
i
]
is
not
None
else
False
req_grad_Y
=
tuple
([
Y
[
i
].
requires_grad
if
Y
[
i
]
is
not
None
else
False
...
@@ -223,14 +223,14 @@ class GSpMM_hetero(th.autograd.Function):
...
@@ -223,14 +223,14 @@ class GSpMM_hetero(th.autograd.Function):
@
staticmethod
@
staticmethod
@
custom_bwd
@
custom_bwd
def
backward
(
ctx
,
*
dZ
):
def
backward
(
ctx
,
*
dZ
):
g
,
op
,
reduce_op
,
X_shape
,
Y_shape
,
dtype
,
device
,
reduce_last
,
X_len
=
ctx
.
backward_cache
g
idx
,
op
,
reduce_op
,
X_shape
,
Y_shape
,
dtype
,
device
,
reduce_last
,
X_len
=
ctx
.
backward_cache
feats
=
ctx
.
saved_tensors
[:
-
2
]
feats
=
ctx
.
saved_tensors
[:
-
2
]
argX
=
ctx
.
saved_tensors
[
-
2
]
argX
=
ctx
.
saved_tensors
[
-
2
]
argY
=
ctx
.
saved_tensors
[
-
1
]
argY
=
ctx
.
saved_tensors
[
-
1
]
X
,
Y
=
feats
[:
X_len
],
feats
[
X_len
:]
X
,
Y
=
feats
[:
X_len
],
feats
[
X_len
:]
if
op
!=
'copy_rhs'
and
any
([
x
is
not
None
for
x
in
X
]):
if
op
!=
'copy_rhs'
and
any
([
x
is
not
None
for
x
in
X
]):
g_rev
=
g
.
reverse
()
g_rev
=
g
idx
.
reverse
()
# TODO(Israt): implement other combinations of message and reduce functions
# TODO(Israt): implement other combinations of message and reduce functions
if
reduce_op
==
'sum'
:
if
reduce_op
==
'sum'
:
if
op
==
'mul'
:
if
op
==
'mul'
:
...
@@ -251,11 +251,11 @@ class GSpMM_hetero(th.autograd.Function):
...
@@ -251,11 +251,11 @@ class GSpMM_hetero(th.autograd.Function):
for
i
in
range
(
len
(
dZ
))])
for
i
in
range
(
len
(
dZ
))])
tpl_X_dZ
=
tuple
(
X
+
tpl_dZ
)
tpl_X_dZ
=
tuple
(
X
+
tpl_dZ
)
if
op
==
'mul'
and
reduce_last
:
if
op
==
'mul'
and
reduce_last
:
dY
=
gsddmm_hetero
(
g
,
'dot'
,
X_len
,
'u'
,
'v'
,
*
tpl_X_dZ
)
dY
=
gsddmm_hetero
(
g
idx
,
'dot'
,
X_len
,
'u'
,
'v'
,
*
tpl_X_dZ
)
elif
op
==
'mul'
:
elif
op
==
'mul'
:
dY
=
gsddmm_hetero
(
g
,
'mul'
,
X_len
,
'u'
,
'v'
,
*
tpl_X_dZ
)
dY
=
gsddmm_hetero
(
g
idx
,
'mul'
,
X_len
,
'u'
,
'v'
,
*
tpl_X_dZ
)
elif
op
in
[
'add'
,
'copy_rhs'
]:
elif
op
in
[
'add'
,
'copy_rhs'
]:
dY
=
gsddmm_hetero
(
g
,
'copy_rhs'
,
X_len
,
'u'
,
'v'
,
*
tpl_X_dZ
)
dY
=
gsddmm_hetero
(
g
idx
,
'copy_rhs'
,
X_len
,
'u'
,
'v'
,
*
tpl_X_dZ
)
dY
=
tuple
([
_reduce_grad
(
dY
[
i
],
Y_shape
[
i
])
if
Y
[
i
]
is
not
None
else
None
dY
=
tuple
([
_reduce_grad
(
dY
[
i
],
Y_shape
[
i
])
if
Y
[
i
]
is
not
None
else
None
for
i
in
range
(
len
(
Y
))])
for
i
in
range
(
len
(
Y
))])
else
:
# Y has no gradient
else
:
# Y has no gradient
...
@@ -345,20 +345,18 @@ class GSDDMM(th.autograd.Function):
...
@@ -345,20 +345,18 @@ class GSDDMM(th.autograd.Function):
class
GSDDMM_hetero
(
th
.
autograd
.
Function
):
class
GSDDMM_hetero
(
th
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
@
custom_fwd
(
cast_inputs
=
th
.
float16
)
@
custom_fwd
(
cast_inputs
=
th
.
float16
)
def
forward
(
ctx
,
g
,
op
,
X_len
,
lhs_target
,
rhs_target
,
*
feats
):
# feats = X+Y
def
forward
(
ctx
,
g
idx
,
op
,
X_len
,
lhs_target
,
rhs_target
,
*
feats
):
# feats = X+Y
out
=
_gsddmm_hetero
(
g
,
op
,
X_len
,
lhs_target
,
rhs_target
,
feats
)
out
=
_gsddmm_hetero
(
g
idx
,
op
,
X_len
,
lhs_target
,
rhs_target
,
feats
)
X
,
Y
=
feats
[:
X_len
],
feats
[
X_len
:]
X
,
Y
=
feats
[:
X_len
],
feats
[
X_len
:]
X_shape
=
tuple
([
X
[
i
].
shape
if
X
[
i
]
is
not
None
else
None
X_shape
=
tuple
([
X
[
i
].
shape
if
X
[
i
]
is
not
None
else
None
for
i
in
range
(
len
(
X
))])
for
i
in
range
(
len
(
X
))])
Y_shape
=
tuple
([
Y
[
i
].
shape
if
Y
[
i
]
is
not
None
else
None
Y_shape
=
tuple
([
Y
[
i
].
shape
if
Y
[
i
]
is
not
None
else
None
for
i
in
range
(
len
(
Y
))])
for
i
in
range
(
len
(
Y
))])
ctx
.
backward_cache
=
g
,
op
,
lhs_target
,
rhs_target
,
X_shape
,
Y_shape
,
X_len
ctx
.
backward_cache
=
g
idx
,
op
,
lhs_target
,
rhs_target
,
X_shape
,
Y_shape
,
X_len
req_grad_X
=
tuple
([
X
[
i
].
requires_grad
if
X
[
i
]
is
not
None
else
False
req_grad_X
=
tuple
([
X
[
i
].
requires_grad
if
X
[
i
]
is
not
None
else
False
for
i
in
range
(
len
(
X
))])
for
i
in
range
(
len
(
X
))])
req_grad_Y
=
tuple
([
Y
[
i
].
requires_grad
if
Y
[
i
]
is
not
None
else
False
req_grad_Y
=
tuple
([
Y
[
i
].
requires_grad
if
Y
[
i
]
is
not
None
else
False
for
i
in
range
(
len
(
Y
))])
for
i
in
range
(
len
(
Y
))])
lhs_id
=
get_typeid_by_target
(
g
,
g
.
canonical_etypes
[
0
],
lhs_target
)
rhs_id
=
get_typeid_by_target
(
g
,
g
.
canonical_etypes
[
0
],
rhs_target
)
ctx
.
save_for_backward
(
*
feats
)
ctx
.
save_for_backward
(
*
feats
)
return
out
return
out
...
@@ -366,56 +364,56 @@ class GSDDMM_hetero(th.autograd.Function):
...
@@ -366,56 +364,56 @@ class GSDDMM_hetero(th.autograd.Function):
@
custom_bwd
@
custom_bwd
# TODO(Israt): Implement the complete backward operator
# TODO(Israt): Implement the complete backward operator
def
backward
(
ctx
,
*
dZ
):
def
backward
(
ctx
,
*
dZ
):
g
,
op
,
lhs_target
,
rhs_target
,
X_shape
,
Y_shape
,
X_len
=
ctx
.
backward_cache
g
idx
,
op
,
lhs_target
,
rhs_target
,
X_shape
,
Y_shape
,
X_len
=
ctx
.
backward_cache
feats
=
ctx
.
saved_tensors
feats
=
ctx
.
saved_tensors
X
,
Y
=
feats
[:
X_len
],
feats
[
X_len
:]
X
,
Y
=
feats
[:
X_len
],
feats
[
X_len
:]
if
op
!=
'copy_rhs'
and
any
([
x
is
not
None
for
x
in
X
]):
if
op
!=
'copy_rhs'
and
any
([
x
is
not
None
for
x
in
X
]):
if
lhs_target
in
[
'u'
,
'v'
]:
if
lhs_target
in
[
'u'
,
'v'
]:
_g
=
g
if
lhs_target
==
'v'
else
g
.
reverse
()
_g
idx
=
g
idx
if
lhs_target
==
'v'
else
g
idx
.
reverse
()
tpl_of_None
=
tuple
([
None
]
*
len
(
X
))
tpl_of_None
=
tuple
([
None
]
*
len
(
X
))
if
op
in
[
'add'
,
'copy_lhs'
]:
if
op
in
[
'add'
,
'copy_lhs'
]:
dX
=
gspmm_hetero
(
_g
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ
)))
dX
=
gspmm_hetero
(
_g
idx
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ
)))
else
:
# mul, dot
else
:
# mul, dot
if
rhs_target
==
lhs_target
:
if
rhs_target
==
lhs_target
:
dX
=
gspmm_hetero
(
_g
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ
)))
*
Y
dX
=
gspmm_hetero
(
_g
idx
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ
)))
*
Y
elif
rhs_target
==
'e'
:
elif
rhs_target
==
'e'
:
dZ_mul_Y
=
tuple
([
dZ
[
i
]
*
Y
[
i
]
if
dZ
[
i
]
is
not
None
else
None
dZ_mul_Y
=
tuple
([
dZ
[
i
]
*
Y
[
i
]
if
dZ
[
i
]
is
not
None
else
None
for
i
in
range
(
len
(
Y
))])
for
i
in
range
(
len
(
Y
))])
dX
=
gspmm_hetero
(
_g
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ_mul_Y
)))
dX
=
gspmm_hetero
(
_g
idx
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ_mul_Y
)))
else
:
# rhs_target = !lhs_target
else
:
# rhs_target = !lhs_target
dX
=
gspmm_hetero
(
_g
,
'mul'
,
'sum'
,
len
(
X
),
*
tuple
(
Y
+
dZ
))
dX
=
gspmm_hetero
(
_g
idx
,
'mul'
,
'sum'
,
len
(
X
),
*
tuple
(
Y
+
dZ
))
else
:
# lhs_target == 'e'
else
:
# lhs_target == 'e'
if
op
in
[
'add'
,
'copy_lhs'
]:
if
op
in
[
'add'
,
'copy_lhs'
]:
dX
=
dZ
dX
=
dZ
else
:
# mul, dot
else
:
# mul, dot
num_etype
=
g
.
_graph
.
number_of_etypes
()
num_etype
=
g
idx
.
number_of_etypes
()
dX
=
gsddmm_hetero
(
g
,
'mul'
,
num_etype
,
'e'
,
rhs_target
,
*
tuple
(
dZ
+
Y
))
dX
=
gsddmm_hetero
(
g
idx
,
'mul'
,
num_etype
,
'e'
,
rhs_target
,
*
tuple
(
dZ
+
Y
))
dX
=
tuple
([
_reduce_grad
(
dX
[
i
],
X_shape
[
i
])
if
X
[
i
]
is
not
None
else
None
dX
=
tuple
([
_reduce_grad
(
dX
[
i
],
X_shape
[
i
])
if
X
[
i
]
is
not
None
else
None
for
i
in
range
(
len
(
X
))])
for
i
in
range
(
len
(
X
))])
else
:
else
:
dX
=
tuple
([
None
]
*
len
(
X
))
dX
=
tuple
([
None
]
*
len
(
X
))
if
op
!=
'copy_lhs'
and
any
([
y
is
not
None
for
y
in
Y
]):
if
op
!=
'copy_lhs'
and
any
([
y
is
not
None
for
y
in
Y
]):
if
rhs_target
in
[
'u'
,
'v'
]:
if
rhs_target
in
[
'u'
,
'v'
]:
_g
=
g
if
rhs_target
==
'v'
else
g
.
reverse
()
_g
idx
=
g
idx
if
rhs_target
==
'v'
else
g
idx
.
reverse
()
tpl_of_None
=
tuple
([
None
]
*
len
(
X
))
tpl_of_None
=
tuple
([
None
]
*
len
(
X
))
if
op
in
[
'add'
,
'copy_rhs'
]:
if
op
in
[
'add'
,
'copy_rhs'
]:
dY
=
gspmm_hetero
(
_g
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ
)))
dY
=
gspmm_hetero
(
_g
idx
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ
)))
else
:
# mul, dot
else
:
# mul, dot
if
lhs_target
==
rhs_target
:
if
lhs_target
==
rhs_target
:
dY
=
gspmm_hetero
(
_g
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ
)))
*
X
dY
=
gspmm_hetero
(
_g
idx
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ
)))
*
X
elif
lhs_target
==
'e'
:
elif
lhs_target
==
'e'
:
dZ_mul_X
=
tuple
([
dZ
[
i
]
*
X
[
i
]
if
dZ
[
i
]
is
not
None
else
None
dZ_mul_X
=
tuple
([
dZ
[
i
]
*
X
[
i
]
if
dZ
[
i
]
is
not
None
else
None
for
i
in
range
(
len
(
X
))])
for
i
in
range
(
len
(
X
))])
dY
=
gspmm_hetero
(
_g
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ_mul_X
)))
dY
=
gspmm_hetero
(
_g
idx
,
'copy_rhs'
,
'sum'
,
len
(
X
),
*
(
tuple
(
tpl_of_None
+
dZ_mul_X
)))
else
:
# rhs_target = !lhs_target
else
:
# rhs_target = !lhs_target
dY
=
gspmm_hetero
(
_g
,
'mul'
,
'sum'
,
len
(
X
),
*
tuple
(
X
+
dZ
))
dY
=
gspmm_hetero
(
_g
idx
,
'mul'
,
'sum'
,
len
(
X
),
*
tuple
(
X
+
dZ
))
else
:
else
:
if
op
in
[
'add'
,
'copy_rhs'
]:
if
op
in
[
'add'
,
'copy_rhs'
]:
dY
=
tuple
([
dZ
[
i
]
if
dZ
[
i
]
is
not
None
else
None
dY
=
tuple
([
dZ
[
i
]
if
dZ
[
i
]
is
not
None
else
None
for
i
in
range
(
len
(
dZ
))])
for
i
in
range
(
len
(
dZ
))])
else
:
# mul, dot
else
:
# mul, dot
num_etype
=
g
.
_graph
.
number_of_etypes
()
num_etype
=
g
idx
.
number_of_etypes
()
dY
=
gsddmm_hetero
(
g
,
'mul'
,
num_etype
,
'e'
,
lhs_target
,
*
tuple
(
dZ
+
X
))
dY
=
gsddmm_hetero
(
g
idx
,
'mul'
,
num_etype
,
'e'
,
lhs_target
,
*
tuple
(
dZ
+
X
))
dY
=
tuple
([
_reduce_grad
(
dY
[
i
],
Y_shape
[
i
])
if
Y
[
i
]
is
not
None
else
None
dY
=
tuple
([
_reduce_grad
(
dY
[
i
],
Y_shape
[
i
])
if
Y
[
i
]
is
not
None
else
None
for
i
in
range
(
len
(
Y
))])
for
i
in
range
(
len
(
Y
))])
else
:
else
:
...
...
python/dgl/ops/sddmm.py
View file @
532eaa87
...
@@ -83,7 +83,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
...
@@ -83,7 +83,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
# different dimensions, and different etypes may need different broadcasting
# different dimensions, and different etypes may need different broadcasting
# dims for the same node.
# dims for the same node.
lhs_and_rhs_tuple
=
tuple
(
list
(
lhs_data
)
+
list
(
rhs_data
))
lhs_and_rhs_tuple
=
tuple
(
list
(
lhs_data
)
+
list
(
rhs_data
))
return
gsddmm_internal_hetero
(
g
,
op
,
len
(
lhs_data
),
lhs_target
,
return
gsddmm_internal_hetero
(
g
.
_graph
,
op
,
len
(
lhs_data
),
lhs_target
,
rhs_target
,
*
lhs_and_rhs_tuple
)
rhs_target
,
*
lhs_and_rhs_tuple
)
def
_gen_sddmm_func
(
lhs_target
,
rhs_target
,
binary_op
):
def
_gen_sddmm_func
(
lhs_target
,
rhs_target
,
binary_op
):
...
...
python/dgl/ops/spmm.py
View file @
532eaa87
...
@@ -84,7 +84,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
...
@@ -84,7 +84,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
rhs_data
=
[
None
]
*
g
.
_graph
.
number_of_etypes
()
if
rhs_data
is
None
else
rhs_data
rhs_data
=
[
None
]
*
g
.
_graph
.
number_of_etypes
()
if
rhs_data
is
None
else
rhs_data
# TODO (Israt): Call reshape func
# TODO (Israt): Call reshape func
lhs_and_rhs_tuple
=
tuple
(
list
(
lhs_data
)
+
list
(
rhs_data
))
lhs_and_rhs_tuple
=
tuple
(
list
(
lhs_data
)
+
list
(
rhs_data
))
ret
=
gspmm_internal_hetero
(
g
,
op
,
ret
=
gspmm_internal_hetero
(
g
.
_graph
,
op
,
'sum'
if
reduce_op
==
'mean'
else
reduce_op
,
'sum'
if
reduce_op
==
'mean'
else
reduce_op
,
len
(
lhs_data
),
*
lhs_and_rhs_tuple
)
len
(
lhs_data
),
*
lhs_and_rhs_tuple
)
# TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph
# TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph
...
...
python/dgl/sparse.py
View file @
532eaa87
...
@@ -64,14 +64,13 @@ def to_dgl_nd_for_write(x):
...
@@ -64,14 +64,13 @@ def to_dgl_nd_for_write(x):
return
nd
.
NULL
[
'int64'
]
if
x
is
None
else
F
.
zerocopy_to_dgl_ndarray_for_write
(
x
)
return
nd
.
NULL
[
'int64'
]
if
x
is
None
else
F
.
zerocopy_to_dgl_ndarray_for_write
(
x
)
def
get_typeid_by_target
(
g
,
rel
,
target
):
def
get_typeid_by_target
(
g
idx
,
etid
,
target
):
"""Find the src/dst/etype id based on the target 'u', 'v' or 'e'."""
"""Find the src/dst/etype id based on the target 'u', 'v' or 'e'."""
srctype
,
_
,
dsttype
=
rel
src_id
,
dst_id
=
gidx
.
metagraph
.
find_edge
(
etid
)
etid
=
g
.
get_etype_id
(
rel
)
if
target
in
[
0
,
'u'
]:
if
target
in
[
0
,
'u'
]:
return
g
.
get_ntype_id
(
srctype
)
return
src_id
if
target
in
[
2
,
'v'
]:
if
target
in
[
2
,
'v'
]:
return
g
.
get_ntype_id
(
dsttype
)
return
dst_id
return
etid
return
etid
...
@@ -190,11 +189,10 @@ def _gspmm(gidx, op, reduce_op, u, e):
...
@@ -190,11 +189,10 @@ def _gspmm(gidx, op, reduce_op, u, e):
return
v
,
(
arg_u
,
arg_e
)
return
v
,
(
arg_u
,
arg_e
)
def
_gspmm_hetero
(
g
,
op
,
reduce_op
,
u_len
,
u_and_e_tuple
):
def
_gspmm_hetero
(
g
idx
,
op
,
reduce_op
,
u_len
,
u_and_e_tuple
):
r
""" Generalized Sparse Matrix Multiplication interface.
r
""" Generalized Sparse Matrix Multiplication interface.
"""
"""
u_tuple
,
e_tuple
=
u_and_e_tuple
[:
u_len
],
u_and_e_tuple
[
u_len
:]
u_tuple
,
e_tuple
=
u_and_e_tuple
[:
u_len
],
u_and_e_tuple
[
u_len
:]
gidx
=
g
.
_graph
use_u
=
op
!=
'copy_rhs'
use_u
=
op
!=
'copy_rhs'
use_e
=
op
!=
'copy_lhs'
use_e
=
op
!=
'copy_lhs'
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
...
@@ -205,11 +203,8 @@ def _gspmm_hetero(g, op, reduce_op, u_len, u_and_e_tuple):
...
@@ -205,11 +203,8 @@ def _gspmm_hetero(g, op, reduce_op, u_len, u_and_e_tuple):
list_v
=
[
None
]
*
gidx
.
number_of_ntypes
()
list_v
=
[
None
]
*
gidx
.
number_of_ntypes
()
list_e
=
[
None
]
*
gidx
.
number_of_etypes
()
list_e
=
[
None
]
*
gidx
.
number_of_etypes
()
for
rel
in
g
.
canonical_etypes
:
for
etid
in
range
(
gidx
.
number_of_etypes
()):
srctype
,
_
,
dsttype
=
rel
src_id
,
dst_id
=
gidx
.
metagraph
.
find_edge
(
etid
)
etid
=
g
.
get_etype_id
(
rel
)
src_id
=
g
.
get_ntype_id
(
srctype
)
dst_id
=
g
.
get_ntype_id
(
dsttype
)
u
=
u_tuple
[
src_id
]
if
use_u
else
None
u
=
u_tuple
[
src_id
]
if
use_u
else
None
e
=
e_tuple
[
etid
]
if
use_e
else
None
e
=
e_tuple
[
etid
]
if
use_e
else
None
if
use_u
:
if
use_u
:
...
@@ -346,10 +341,9 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
...
@@ -346,10 +341,9 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
return
out
return
out
def
_gsddmm_hetero
(
g
,
op
,
lhs_len
,
lhs_target
=
'u'
,
rhs_target
=
'v'
,
lhs_and_rhs_tuple
=
None
):
def
_gsddmm_hetero
(
g
idx
,
op
,
lhs_len
,
lhs_target
=
'u'
,
rhs_target
=
'v'
,
lhs_and_rhs_tuple
=
None
):
r
""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
r
""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
"""
"""
gidx
=
g
.
_graph
lhs_tuple
,
rhs_tuple
=
lhs_and_rhs_tuple
[:
lhs_len
],
lhs_and_rhs_tuple
[
lhs_len
:]
lhs_tuple
,
rhs_tuple
=
lhs_and_rhs_tuple
[:
lhs_len
],
lhs_and_rhs_tuple
[
lhs_len
:]
use_lhs
=
op
!=
'copy_rhs'
use_lhs
=
op
!=
'copy_rhs'
...
@@ -358,8 +352,8 @@ def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_t
...
@@ -358,8 +352,8 @@ def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_t
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
# deal with scalar features.
# deal with scalar features.
expand_lhs
,
expand_rhs
=
False
,
False
expand_lhs
,
expand_rhs
=
False
,
False
num_ntype
=
g
.
_graph
.
number_of_ntypes
()
num_ntype
=
g
idx
.
number_of_ntypes
()
num_etype
=
g
.
_graph
.
number_of_etypes
()
num_etype
=
g
idx
.
number_of_etypes
()
lhs_list
=
[
None
]
*
num_ntype
if
lhs_target
in
[
'u'
,
'v'
]
else
[
None
]
*
num_etype
lhs_list
=
[
None
]
*
num_ntype
if
lhs_target
in
[
'u'
,
'v'
]
else
[
None
]
*
num_etype
rhs_list
=
[
None
]
*
num_ntype
if
rhs_target
in
[
'u'
,
'v'
]
else
[
None
]
*
num_etype
rhs_list
=
[
None
]
*
num_ntype
if
rhs_target
in
[
'u'
,
'v'
]
else
[
None
]
*
num_etype
out_list
=
[
None
]
*
gidx
.
number_of_etypes
()
out_list
=
[
None
]
*
gidx
.
number_of_etypes
()
...
@@ -367,10 +361,9 @@ def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_t
...
@@ -367,10 +361,9 @@ def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_t
lhs_target
=
target_mapping
[
lhs_target
]
lhs_target
=
target_mapping
[
lhs_target
]
rhs_target
=
target_mapping
[
rhs_target
]
rhs_target
=
target_mapping
[
rhs_target
]
for
rel
in
g
.
canonical_etypes
:
for
etid
in
range
(
gidx
.
number_of_etypes
()):
etid
=
g
.
get_etype_id
(
rel
)
lhs_id
=
get_typeid_by_target
(
gidx
,
etid
,
lhs_target
)
lhs_id
=
get_typeid_by_target
(
g
,
rel
,
lhs_target
)
rhs_id
=
get_typeid_by_target
(
gidx
,
etid
,
rhs_target
)
rhs_id
=
get_typeid_by_target
(
g
,
rel
,
rhs_target
)
lhs
=
lhs_tuple
[
lhs_id
]
lhs
=
lhs_tuple
[
lhs_id
]
rhs
=
rhs_tuple
[
rhs_id
]
rhs
=
rhs_tuple
[
rhs_id
]
if
use_lhs
:
if
use_lhs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment