"...pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "5d2d14538a87e45891609d172d7aa05a1b756068"
Unverified Commit 89322856 authored by nv-dlasalle, committed by GitHub

[bugfix] Fix sparse_optim when the state is stored on the CPU (fixes #2760) (#3013)



* Fix sparse optimizer to wait on copies to the CPU

* Fix linting

* Fix typo
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent b4ad59d7
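
What went wrong: `Tensor.to(..., non_blocking=True)` can return before a device-to-host copy has finished, so CPU code that reads the destination tensor right away may observe stale or partially written values. That is the situation hit in #2760 when the optimizer state is stored on the CPU. A minimal sketch of the hazard (illustrative only; shapes and device names are arbitrary, not taken from the commit):

    import torch as th

    if th.cuda.is_available():
        x = th.randn(1024, 1024, device='cuda:0')
        y = (x * 2).to('cpu', non_blocking=True)  # copy may still be in flight
        # Reading y here races with the asynchronous copy. Either drop
        # non_blocking for device-to-host copies, or synchronize first:
        th.cuda.synchronize()
        print(y.sum())  # safe only after the synchronize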
@@ -598,6 +598,12 @@ class SparseAdam(SparseGradOptimizer):
         exec_dev = grad.device
         state_dev = state_step.device
+        # only perform async copies cpu -> gpu, or gpu -> gpu, but block
+        # when copying to the cpu, so as to ensure the copy is finished
+        # before operating on the data on the cpu
+        state_nonblock = state_dev != th.device('cpu')
+        exec_nonblock = exec_dev != th.device('cpu')
+
         # There can be duplicated indices due to sampling.
         # Thus unique them here and average the gradient here.
         grad_indices, inverse, cnt = th.unique(idx,
@@ -605,9 +611,9 @@ class SparseAdam(SparseGradOptimizer):
                                                return_counts=True)
         state_idx = grad_indices.to(state_dev)
         state_step[state_idx] += 1
-        state_step = state_step[state_idx].to(exec_dev, non_blocking=True)
-        orig_mem = state_mem[state_idx].to(exec_dev, non_blocking=True)
-        orig_power = state_power[state_idx].to(exec_dev, non_blocking=True)
+        state_step = state_step[state_idx].to(exec_dev, non_blocking=exec_nonblock)
+        orig_mem = state_mem[state_idx].to(exec_dev, non_blocking=exec_nonblock)
+        orig_power = state_power[state_idx].to(exec_dev, non_blocking=exec_nonblock)
         grad_values = th.zeros((grad_indices.shape[0], grad.shape[1]), device=exec_dev)
         grad_values.index_add_(0, inverse, grad)
@@ -617,8 +623,10 @@ class SparseAdam(SparseGradOptimizer):
         grad_power = grad_values * grad_values
         update_mem = beta1 * orig_mem + (1.-beta1) * grad_mem
         update_power = beta2 * orig_power + (1.-beta2) * grad_power
-        state_mem[state_idx] = update_mem.to(state_dev, non_blocking=True)
-        state_power[state_idx] = update_power.to(state_dev, non_blocking=True)
+        state_mem[state_idx] = update_mem.to(state_dev,
+                                             non_blocking=state_nonblock)
+        state_power[state_idx] = update_power.to(state_dev,
+                                                 non_blocking=state_nonblock)
         update_mem_corr = update_mem / (1. - th.pow(th.tensor(beta1, device=exec_dev),
                                                     state_step)).unsqueeze(1)
...
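
The guard added above can be read as a small reusable pattern: copy asynchronously between GPUs or from CPU to GPU, but block whenever the destination is the CPU. A standalone sketch of that pattern (the helper name `to_device` is ours, not a DGL API):

    import torch as th

    def to_device(tensor, dst_dev):
        # dst_dev is a th.device. Copies are asynchronous unless the
        # destination is the CPU, where we block so the data is complete
        # before the host reads it.
        nonblock = dst_dev.type != 'cpu'
        return tensor.to(dst_dev, non_blocking=nonblock)

    # e.g. in an update step:
    #   work  = to_device(state, exec_dev)   # may overlap with GPU compute
    #   state = to_device(work, state_dev)   # blocks if state_dev is the CPU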