OpenDAS / dgl · Commit ff8f7082 (unverified)

Authored Jul 27, 2020 by Da Zheng; committed by GitHub on Jul 27, 2020.
[Distributed] Turn off recording on embeddings during inference (#1861)

* Turn recording on/off in the sparse embedding.
* Add a test.
Parent: bcb988bd
Showing 6 changed files, with 24 additions and 3 deletions (+24 -3):

    python/dgl/backend/backend.py               +5 -0
    python/dgl/backend/mxnet/tensor.py          +3 -0
    python/dgl/backend/pytorch/tensor.py        +3 -0
    python/dgl/backend/tensorflow/tensor.py     +3 -0
    python/dgl/distributed/sparse_emb.py        +4 -2
    tests/distributed/test_dist_graph_store.py  +6 -1
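Taken together, the change makes DistEmbedding lookups attach gradients and record a trace only while the backend reports that autograd recording is active. A minimal usage sketch of the pattern this enables, mirroring the updated test below; it assumes `g` is an initialized DistGraph, `nids` a node-ID tensor, and `emb_init` an initializer, and the import paths follow the test file rather than a documented public API:

from dgl import backend as F
from dgl.distributed import DistEmbedding, SparseAdagrad

emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb', emb_init)
optimizer = SparseAdagrad([emb], lr=0.1)

# Training: recording is on, so each lookup attaches gradients and
# appends (idx, emb) to the embedding's gradient trace.
with F.record_grad():
    feats = emb(nids)
    loss = F.sum(feats + 1, 0)
    loss.backward()
    optimizer.step()

# Inference: recording is off, so the lookup skips attach_grad and
# leaves the trace empty, which is exactly what this commit enables.
with F.no_grad():
    feats = emb(nids)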
python/dgl/backend/backend.py

@@ -1474,6 +1474,11 @@ def is_no_grad(x):
     """
     pass
 
+def is_recording():
+    """ Test if the execution is recording gradients.
+    """
+    pass
+
 class record_grad(object):
     """Context manager that records the gradients"""
     def __init__(self):
python/dgl/backend/mxnet/tensor.py

@@ -605,6 +605,9 @@ def grad(x):
 def is_no_grad(x):
     return (x != 0).sum() == 0
 
+def is_recording():
+    return mx.autograd.is_recording()
+
 record_grad = mx.autograd.record
 
 class no_grad(object):
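For reference, the MXNet calls wrapped here behave as follows; a minimal standalone check, assuming only that MXNet is installed:

import mxnet as mx

# Outside any record scope, autograd is not recording.
assert not mx.autograd.is_recording()

# Inside mx.autograd.record() (aliased above as record_grad), it is.
with mx.autograd.record():
    assert mx.autograd.is_recording()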
python/dgl/backend/pytorch/tensor.py

@@ -517,6 +517,9 @@ def grad(x):
 def is_no_grad(x):
     return x.grad is None or (x.grad == 0).all()
 
+def is_recording():
+    return th.is_grad_enabled()
+
 class record_grad(object):
     def __init__(self):
         pass
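Likewise, th.is_grad_enabled() tracks PyTorch's global grad mode, which is what lets no_grad/record_grad scopes drive the new check; a minimal sketch:

import torch as th

assert th.is_grad_enabled()            # grad mode is on by default
with th.no_grad():
    assert not th.is_grad_enabled()    # off inside a no_grad scope
    with th.enable_grad():
        assert th.is_grad_enabled()    # re-enabled inside enable_grad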
python/dgl/backend/tensorflow/tensor.py

@@ -685,6 +685,9 @@ def grad(x):
 def is_no_grad(x):
     return cgrad.is_no_grad(x)
 
+def is_recording():
+    raise NotImplementedError("Tensorflow doesn't support is_recording")
+
 no_grad = None
 
 initialize_context()
python/dgl/distributed/sparse_emb.py

@@ -47,7 +47,9 @@ class DistEmbedding:
     def __call__(self, idx):
         idx = utils.toindex(idx).tousertensor()
-        emb = F.attach_grad(self._tensor[idx])
-        self._trace.append((idx, emb))
+        emb = self._tensor[idx]
+        if F.is_recording():
+            emb = F.attach_grad(emb)
+            self._trace.append((idx, emb))
         return emb
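To make the new control flow concrete, here is a self-contained toy version of __call__, using a plain PyTorch tensor in place of the distributed tensor; ToyEmbedding and its internals are illustrative stand-ins, not DGL API:

import torch as th

class ToyEmbedding:
    """Stand-in for DistEmbedding: a local tensor plus a gradient trace."""
    def __init__(self, num, dim):
        self._tensor = th.zeros(num, dim)
        self._trace = []

    def __call__(self, idx):
        emb = self._tensor[idx]
        if th.is_grad_enabled():                # plays the role of F.is_recording()
            emb = emb.clone().requires_grad_()  # plays the role of F.attach_grad(emb)
            self._trace.append((idx, emb))
        return emb

toy = ToyEmbedding(10, 4)
with th.no_grad():
    out = toy(th.tensor([0, 1]))     # inference path: nothing traced
assert not out.requires_grad and len(toy._trace) == 0

out = toy(th.tensor([2, 3]))         # recording path: grads attached, traced
assert out.requires_grad and len(toy._trace) == 1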
tests/distributed/test_dist_graph_store.py

@@ -142,6 +142,10 @@ def check_dist_graph(g, num_nodes, num_edges):
     assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))
 
     emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb2', emb_init)
+    with F.no_grad():
+        feats1 = emb(nids)
+    assert np.all(F.asnumpy(feats1) == 0)
+
     optimizer = SparseAdagrad([emb], lr=lr)
     with F.record_grad():
         feats1 = emb(nids)

@@ -151,6 +155,7 @@ def check_dist_graph(g, num_nodes, num_edges):
         loss = F.sum(feats + 1, 0)
         loss.backward()
         optimizer.step()
-    feats = emb(nids)
+    with F.no_grad():
+        feats = emb(nids)
     assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr)
     rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))