"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e564abe292750b7d2eef07f2b49ea2056df391ab"
Unverified Commit ff8f7082 authored by Da Zheng, committed by GitHub

[Distributed] turn off recording on embeddings in the inference. (#1861)

* turn on/off recording in sparse embedding.

* add test.
parent bcb988bd
@@ -1474,6 +1474,11 @@ def is_no_grad(x):
     """
     pass

+def is_recording():
+    """ Test if the execution is recording gradients.
+    """
+    pass
+
 class record_grad(object):
     """Context manager that records the gradients"""
     def __init__(self):
...
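For context: this hunk adds the framework-agnostic stub in the backend interface, and each backend below supplies the concrete implementation. A minimal sketch of how backend-neutral code is expected to use the hook (assuming DGL's internal backend module is importable as dgl.backend, matching the F alias used elsewhere in this diff):

    from dgl import backend as F  # assumed import path for the internal backend alias

    def lookup(tensor, idx):
        # Pay the autograd bookkeeping cost only while gradients are being
        # recorded (training); plain indexing suffices for inference.
        emb = tensor[idx]
        if F.is_recording():
            emb = F.attach_grad(emb)
        return emb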
@@ -605,6 +605,9 @@ def grad(x):
 def is_no_grad(x):
     return (x != 0).sum() == 0

+def is_recording():
+    return mx.autograd.is_recording()
+
 record_grad = mx.autograd.record

 class no_grad(object):
...
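The MXNet backend can delegate directly: mx.autograd.is_recording() is a public MXNet API that reports whether an autograd recording scope is active. A quick illustration:

    import mxnet as mx

    print(mx.autograd.is_recording())      # False: no recording scope active
    with mx.autograd.record():
        print(mx.autograd.is_recording())  # True inside record()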
@@ -517,6 +517,9 @@ def grad(x):
 def is_no_grad(x):
     return x.grad is None or (x.grad == 0).all()

+def is_recording():
+    return th.is_grad_enabled()
+
 class record_grad(object):
     def __init__(self):
         pass
...
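For PyTorch the closest notion is grad mode, so the backend maps is_recording() onto torch.is_grad_enabled(), which torch.no_grad() toggles off:

    import torch as th

    print(th.is_grad_enabled())        # True by default
    with th.no_grad():
        print(th.is_grad_enabled())    # False while grads are disabled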
@@ -685,6 +685,9 @@ def grad(x):
 def is_no_grad(x):
     return cgrad.is_no_grad(x)

+def is_recording():
+    raise NotImplementedError("Tensorflow doesn't support is_recording")
+
 no_grad = None

 initialize_context()
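TensorFlow records gradients per tf.GradientTape instance rather than via a global mode, and it exposes no public query for "is any tape currently recording", which is presumably why this backend raises NotImplementedError. For comparison:

    import tensorflow as tf

    x = tf.constant(3.0)
    with tf.GradientTape() as tape:
        tape.watch(x)        # recording is a property of this tape object
        y = x * x
    print(tape.gradient(y, x))  # tf.Tensor(6.0, shape=(), dtype=float32)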
@@ -47,8 +47,10 @@ class DistEmbedding:
     def __call__(self, idx):
         idx = utils.toindex(idx).tousertensor()
-        emb = F.attach_grad(self._tensor[idx])
-        self._trace.append((idx, emb))
+        emb = self._tensor[idx]
+        if F.is_recording():
+            emb = F.attach_grad(emb)
+            self._trace.append((idx, emb))
         return emb

 class SparseAdagradUDF:
...
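This is the core fix: previously every lookup attached gradients and appended to self._trace, so pure-inference lookups grew the trace (and its memory) with entries no optimizer step would ever drain. Now the trace is only populated while recording. A usage sketch built from the names in the test below:

    emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb2', emb_init)

    with F.no_grad():        # inference: nothing is appended to emb._trace
        feats = emb(nids)

    with F.record_grad():    # training: the lookup is traced for SparseAdagrad
        feats = emb(nids)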
@@ -142,6 +142,10 @@ def check_dist_graph(g, num_nodes, num_edges):
     assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

     emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb2', emb_init)
+    with F.no_grad():
+        feats1 = emb(nids)
+    assert np.all(F.asnumpy(feats1) == 0)
+
     optimizer = SparseAdagrad([emb], lr=lr)
     with F.record_grad():
         feats1 = emb(nids)
@@ -151,7 +155,8 @@ def check_dist_graph(g, num_nodes, num_edges):
     loss = F.sum(feats + 1, 0)
     loss.backward()
     optimizer.step()
-    feats = emb(nids)
+    with F.no_grad():
+        feats = emb(nids)
     assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr)
     rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
     feats1 = emb(rest)
...
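The added assertions pin down both halves of the behavior: a no_grad() lookup must return the untouched initial values (zeros here) and must leave no entry in the trace that a later optimizer.step() would mistakenly consume. A condensed sketch of the scenario the test exercises (illustrative, not the literal test body):

    with F.no_grad():
        _ = emb(nids)              # must NOT be traced

    optimizer = SparseAdagrad([emb], lr=lr)
    with F.record_grad():
        feats = emb(nids)          # traced lookup
        loss = F.sum(feats + 1, 0)
        loss.backward()
    optimizer.step()               # applies exactly one sparse update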