Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6294677f
"src/vscode:/vscode.git/clone" did not exist on "f3fbf9bfc0c4613e93faa4500629f77fae32c3e6"
Unverified
Commit
6294677f
authored
Aug 13, 2020
by
Zihao Ye
Committed by
GitHub
Aug 13, 2020
Browse files
[hotfix] Set reduce results to all zero for nodes with zero in-degrees. (#2011)
parent
53629082
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
10 deletions
+23
-10
python/dgl/backend/pytorch/tensor.py
python/dgl/backend/pytorch/tensor.py
+1
-1
python/dgl/ops/spmm.py
python/dgl/ops/spmm.py
+21
-5
tests/compute/test_sparse.py
tests/compute/test_sparse.py
+1
-4
No files found.
python/dgl/backend/pytorch/tensor.py
View file @
6294677f
...
...
@@ -277,7 +277,7 @@ def full_1d(length, fill_value, dtype, ctx):
return
th
.
full
((
length
,),
fill_value
,
dtype
=
dtype
,
device
=
ctx
)
def
nonzero_1d
(
input
):
x
=
th
.
nonzero
(
input
).
squeeze
()
x
=
th
.
nonzero
(
input
,
as_tuple
=
False
).
squeeze
()
return
x
if
x
.
dim
()
==
1
else
x
.
view
(
-
1
)
def
sort_1d
(
input
):
...
...
python/dgl/ops/spmm.py
View file @
6294677f
...
...
@@ -2,7 +2,7 @@
import
sys
from
..base
import
dgl_warning
from
..backend
import
gspmm
as
gspmm_internal
from
..backend
import
gspmm
as
gspmm_internal
,
backend_name
from
..
import
backend
as
F
__all__
=
[
'gspmm'
]
...
...
@@ -59,17 +59,33 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
new_rhs_shape
=
(
rhs_shape
[
0
],)
+
(
1
,)
*
rhs_pad_ndims
+
rhs_shape
[
1
:]
lhs_data
=
F
.
reshape
(
lhs_data
,
new_lhs_shape
)
rhs_data
=
F
.
reshape
(
rhs_data
,
new_rhs_shape
)
ret
=
gspmm_internal
(
g
.
_graph
,
op
,
'sum'
if
reduce_op
==
'mean'
else
reduce_op
,
lhs_data
,
rhs_data
)
# assign zero features for zero degree nodes.
deg
=
g
.
in_degrees
()
min_deg
=
F
.
as_scalar
(
F
.
min
(
deg
,
dim
=
0
))
if
min_deg
==
0
:
non_zero_nids
=
F
.
nonzero_1d
(
deg
==
0
)
if
backend_name
==
'pytorch'
:
ret
[
non_zero_nids
]
=
0.
else
:
dtype
=
F
.
dtype
(
ret
)
ctx
=
F
.
context
(
ret
)
ret
=
F
.
scatter_row
(
ret
,
non_zero_nids
,
F
.
zeros
((
len
(
non_zero_nids
),)
+
F
.
shape
(
ret
)[
1
:],
dtype
,
ctx
))
# divide in degrees for mean reducer.
if
reduce_op
==
'mean'
:
ret
=
gspmm_internal
(
g
.
_graph
,
op
,
'sum'
,
lhs_data
,
rhs_data
)
ret_shape
=
F
.
shape
(
ret
)
deg
=
g
.
in_degrees
()
if
F
.
as_scalar
(
F
.
min
(
deg
,
dim
=
0
))
==
0
:
if
min_deg
==
0
:
dgl_warning
(
'Zero-degree nodes encountered in mean reducer. Setting the mean to 0.'
)
deg
=
F
.
astype
(
F
.
clamp
(
deg
,
1
,
g
.
number_of_edges
()),
F
.
dtype
(
ret
))
deg_shape
=
(
ret_shape
[
0
],)
+
(
1
,)
*
(
len
(
ret_shape
)
-
1
)
return
ret
/
F
.
reshape
(
deg
,
deg_shape
)
else
:
return
gspmm_internal
(
g
.
_graph
,
op
,
reduce_op
,
lhs_data
,
rhs_data
)
return
ret
def
_gen_spmm_func
(
binary_op
,
reduce_op
):
...
...
tests/compute/test_sparse.py
View file @
6294677f
...
...
@@ -116,9 +116,6 @@ def test_spmm(idtype, g, shp, msg, reducer):
e
=
F
.
attach_grad
(
F
.
clone
(
he
))
with
F
.
record_grad
():
v
=
gspmm
(
g
,
msg
,
reducer
,
u
,
e
)
non_degree_indices
=
F
.
tensor
(
np
.
nonzero
(
F
.
asnumpy
(
g
.
in_degrees
())
!=
0
)[
0
])
v
=
F
.
gather_row
(
v
,
non_degree_indices
)
if
g
.
number_of_edges
()
>
0
:
F
.
backward
(
F
.
reduce_sum
(
v
))
if
msg
!=
'copy_rhs'
:
...
...
@@ -129,7 +126,7 @@ def test_spmm(idtype, g, shp, msg, reducer):
with
F
.
record_grad
():
g
.
update_all
(
udf_msg
[
msg
],
udf_reduce
[
reducer
])
if
g
.
number_of_edges
()
>
0
:
v1
=
F
.
gather_row
(
g
.
dstdata
[
'v'
]
,
non_degree_indices
)
v1
=
g
.
dstdata
[
'v'
]
assert
F
.
allclose
(
v
,
v1
)
print
(
'forward passed'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment