Unverified commit 6001001f, authored by nv-dlasalle, committed by GitHub

Improve usage of pinned memory in sparse_optimizer (#3207)


Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent c7935935
@@ -626,8 +626,7 @@ class SparseAdam(SparseGradOptimizer):
         # only perform async copies cpu -> gpu, or gpu-> gpu, but block
         # when copying to the cpu, so as to ensure the copy is finished
         # before operating on the data on the cpu
-        state_nonblock = False # state_dev != th.device('cpu')
-        exec_nonblock = False # exec_dev != th.device('cpu')
+        state_block = state_dev == th.device('cpu') and exec_dev != state_dev
         # There can be duplicated indices due to sampling.
         # Thus unique them here and average the gradient here.
@@ -636,9 +635,9 @@ class SparseAdam(SparseGradOptimizer):
                                                  return_counts=True)
         state_idx = grad_indices.to(state_dev)
         state_step[state_idx] += 1
-        state_step = state_step[state_idx].to(exec_dev, non_blocking=exec_nonblock)
-        orig_mem = state_mem[state_idx].to(exec_dev, non_blocking=exec_nonblock)
-        orig_power = state_power[state_idx].to(exec_dev, non_blocking=exec_nonblock)
+        state_step = state_step[state_idx].to(exec_dev)
+        orig_mem = state_mem[state_idx].to(exec_dev)
+        orig_power = state_power[state_idx].to(exec_dev)
         grad_values = th.zeros((grad_indices.shape[0], grad.shape[1]), device=exec_dev)
         grad_values.index_add_(0, inverse, grad)
@@ -646,17 +645,34 @@ class SparseAdam(SparseGradOptimizer):
         grad_mem = grad_values
         grad_power = grad_values * grad_values
         update_mem = beta1 * orig_mem + (1.-beta1) * grad_mem
         update_power = beta2 * orig_power + (1.-beta2) * grad_power
-        state_mem[state_idx] = update_mem.to(state_dev,
-                                             non_blocking=state_nonblock)
-        state_power[state_idx] = update_power.to(state_dev,
-                                                 non_blocking=state_nonblock)
+        update_mem_dst = update_mem.to(state_dev, non_blocking=True)
+        update_power_dst = update_power.to(state_dev, non_blocking=True)
+        if state_block:
+            # use events to try and overlap CPU and GPU as much as possible
+            update_event = th.cuda.Event()
+            update_event.record()
         update_mem_corr = update_mem / (1. - th.pow(th.tensor(beta1, device=exec_dev),
                                                     state_step)).unsqueeze(1)
         update_power_corr = update_power / (1. - th.pow(th.tensor(beta2, device=exec_dev),
                                                         state_step)).unsqueeze(1)
         std_values = clr * update_mem_corr / (th.sqrt(update_power_corr) + eps)
-        emb.weight[state_idx] -= std_values.to(state_dev)
+        std_values_dst = std_values.to(state_dev, non_blocking=True)
+        if state_block:
+            std_event = th.cuda.Event()
+            std_event.record()
+        if state_block:
+            # wait for our transfers from exec_dev to state_dev to finish
+            # before we can use them
+            update_event.wait()
+        state_mem[state_idx] = update_mem_dst
+        state_power[state_idx] = update_power_dst
+        if state_block:
+            # wait for the transfer of std_values to finish before we
+            # can use it
+            std_event.wait()
+        emb.weight[state_idx] -= std_values_dst
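For context, the pattern this change leans on is standard PyTorch: issue non-blocking copies into pinned (page-locked) host memory, record a CUDA event on the current stream, keep queuing GPU work, and only synchronize on the event right before the CPU actually reads the copied data. The sketch below is illustrative only and is not taken from the DGL sources; the names (cpu_state, gpu_update, copy_done) are invented for the example.

# Minimal sketch of the copy/overlap pattern (illustrative, not DGL code).
import torch as th

if th.cuda.is_available():
    exec_dev = th.device('cuda')

    # pinned (page-locked) host memory is what allows a device -> host copy
    # issued with non_blocking=True to actually run asynchronously
    cpu_state = th.empty(1024, 16, pin_memory=True)
    gpu_update = th.randn(1024, 16, device=exec_dev)

    # queue the async copy back into the pinned CPU buffer
    cpu_state.copy_(gpu_update, non_blocking=True)

    # mark the point in the current stream where the copy was enqueued
    copy_done = th.cuda.Event()
    copy_done.record()

    # overlap: more GPU work can be launched while the copy is in flight
    gpu_result = th.sqrt(gpu_update * gpu_update + 1.0)

    # block the host only when the CPU actually needs the copied data
    copy_done.synchronize()
    print(cpu_state.sum(), gpu_result.shape)

As a general PyTorch note, Event.synchronize() blocks the calling host thread until the recorded work has completed, whereas Event.wait() makes a CUDA stream wait on the event; the sketch uses synchronize() because the consumer of the copied data here is host code.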