Unverified Commit 156c17f3 authored by xiang song (charlie.song), committed by GitHub

[BugFix] Fix a bug of distributed SparseAdam optimizer. (#3561)



* Fix dist user embed bug

* Add some comments

* trigger
Co-authored-by: xiangsx <xiangsx@ip-10-3-83-86.eu-west-1.compute.internal>
parent cb0e1103
@@ -331,8 +331,18 @@ class SparseAdam(DistSparseGradOptimizer):
grad_indices, inverse, cnt = th.unique(idx, return_inverse=True, return_counts=True)
# update grad state
state_idx = grad_indices.to(state_dev)
state_step[state_idx] += 1
state_step = state_step[state_idx].to(exec_dev, non_blocking=True)
# The original implementation can cause read/write contention.
# state_step[state_idx] += 1
# state_step = state_step[state_idx].to(exec_dev, non_blocking=True)
# In a distributed environment, the first line sends asynchronous write requests to the
# kvstore servers to update state_step, and the second line sends read requests to the
# kvstore servers. The write and read requests may be handled by different kvstore servers
# that manage the same portion of the state_step dist tensor on the same node, so the read
# may return a stale value (i.e., 0 in the first iteration), which makes
# update_power_corr NaN.
state_val = state_step[state_idx] + 1
state_step[state_idx] = state_val
state_step = state_val.to(exec_dev, non_blocking=True)
orig_mem = state_mem[state_idx].to(exec_dev, non_blocking=True)
orig_power = state_power[state_idx].to(exec_dev, non_blocking=True)
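
For context, the NaN mentioned in the comment comes from Adam's bias correction, which divides by `1 - beta**step`; if the stale read returns a step of 0, that denominator is 0. The following is a minimal, self-contained sketch of the arithmetic only; the beta values and the `grad_mean`/`grad_power` names are placeholders, not the optimizer's actual variables.

```python
import torch as th

beta1, beta2 = 0.9, 0.999          # typical Adam betas (placeholder values)
grad_mean = th.tensor([0.1])       # placeholder first-moment estimate
grad_power = th.tensor([0.01])     # placeholder second-moment estimate

# With a stale step of 0, (1 - beta ** step) == 0 and the bias-corrected
# terms blow up; with the locally computed step of 1 they stay finite.
for step in (th.tensor(0.0), th.tensor(1.0)):
    mem_corr = grad_mean / (1.0 - beta1 ** step)
    power_corr = grad_power / (1.0 - beta2 ** step)
    print(step.item(), mem_corr.item(), power_corr.item())
# step 0.0 -> inf, inf (and NaN downstream, e.g. 0/0 when a gradient entry is zero)
# step 1.0 -> 1.0, 10.0
```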