Commit b85ff391 authored by Thor Johnsen

Add option to revert step through double buffering

parent ffed6e80
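A minimal usage sketch of the new `revert_method` knob added by this commit. The import path, the model, and every argument other than `revert_method` are illustrative assumptions rather than part of this commit, and `torch.distributed` process-group initialization is assumed to have happened already:

import torch
# Import path assumed; matches apex's contrib layout of the period.
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

# DistributedFusedAdam shards optimizer state across ranks, so
# torch.distributed.init_process_group(...) must already be done.
model = torch.nn.Linear(1024, 1024).half().cuda()
optimizer = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    # New in this commit:
    #   1 -> undo kernel (default, no extra memory)
    #   2 -> double-buffered fp32 state (exact revert, more memory)
    #   3 -> both, printing the norm of the difference (debug; raises
    #        "not implemented yet" in this commit)
    revert_method=2,
)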
@@ -45,7 +45,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
-                 dwu_num_ag_pg=0, dwu_num_blk_st=1):
+                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
@@ -64,6 +64,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._overflow_buf = torch.cuda.IntTensor([0])
+
+        # Way to revert a step
+        # 3 -> undo kernel + double buffer (debug, print norm of difference)
+        # 2 -> double buffer fp32 parameters
+        # 1 -> undo kernel
+        self._revert_method = revert_method
+        if self._revert_method > 1:
+            print("revert_method -> double buffer fp32 parameters, will consume more memory")
         self._last_step = False
         self._overlap_reductions = overlap_reductions
         self._global_scale = None
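The three modes trade memory for fidelity. A self-contained sketch of the double-buffer idea behind `revert_method == 2` (tensor names are illustrative; in the diff the second buffers are the `_fp32_backup_*` tensors allocated in the next hunk): snapshot the fp32 state before a step, restore it to revert.

import torch

# Illustrative state: fp32 master weights and Adam moments.
p = torch.randn(8)
m = torch.zeros(8)
v = torch.zeros(8)
backup_p, backup_m, backup_v = (torch.empty_like(t) for t in (p, m, v))

# Before the optimizer step: snapshot current state into the second buffer.
for dst, src in ((backup_p, p), (backup_m, m), (backup_v, v)):
    dst.copy_(src)

# ... the fused Adam step mutates p, m, v in place ...
p.add_(0.1); m.add_(0.2); v.add_(0.3)  # stand-in for the real kernel

# Revert: restore the snapshot (bit-exact, unlike an arithmetic undo).
for dst, src in ((p, backup_p), (m, backup_m), (v, backup_v)):
    dst.copy_(src)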
@@ -314,6 +322,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     self._fp32_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
                     self._fp32_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
                     self._fp32_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+                    if self._revert_method > 1:
+                        self._fp32_backup_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+                        self._fp32_backup_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
+                        self._fp32_backup_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
                     self._copy_to_fp32 = True
             step = None
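A back-of-envelope estimate of the extra memory the constructor's warning refers to, under assumed sizes: three additional fp32 buffers of `num_blocks * shard_size` elements each.

# Rough cost of revert_method > 1. Sizes below are assumed examples.
num_blocks = 4            # matches the dwu_num_blocks default
shard_size = 100_000_000  # hypothetical per-shard element count
extra_bytes = 3 * num_blocks * shard_size * 4  # 4 bytes per fp32 element
print(f"extra optimizer-state memory: {extra_bytes / 2**30:.1f} GiB")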
...@@ -376,21 +388,32 @@ class DistributedFusedAdam(torch.optim.Optimizer): ...@@ -376,21 +388,32 @@ class DistributedFusedAdam(torch.optim.Optimizer):
beta1, beta2 = group['betas'] beta1, beta2 = group['betas']
if undo: if undo:
fused_adam_cuda.adam_undo( if self._revert_method == 1:
self._fp32_p[group_buffer_start:group_buffer_end], fused_adam_cuda.adam_undo(
self._fp32_m[group_buffer_start:group_buffer_end], self._fp32_p[group_buffer_start:group_buffer_end],
self._fp32_v[group_buffer_start:group_buffer_end], self._fp32_m[group_buffer_start:group_buffer_end],
self._flat_grads[group_shard_start:group_shard_end], self._fp32_v[group_buffer_start:group_buffer_end],
group['lr'], self._flat_grads[group_shard_start:group_shard_end],
beta1, group['lr'],
beta2, beta1,
group['eps'], beta2,
combined_scale, group['eps'],
step+1, # FIXME: Verify this should be step+1 combined_scale,
self.eps_mode, step+1, # FIXME: Verify this should be step+1
bias_correction, self.eps_mode,
group['weight_decay']) bias_correction,
group['weight_decay'])
elif self._revert_method == 2:
self._fp32_p[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_p[group_buffer_start:group_buffer_end])
self._fp32_m[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_m[group_buffer_start:group_buffer_end])
self._fp32_v[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_v[group_buffer_start:group_buffer_end])
elif self._revert_method == 3:
raise RuntimeError('revert_step debug option not implemented yet')
else: else:
if self._revert_method > 1:
self._fp32_backup_p[group_buffer_start:group_buffer_end].copy_(self._fp32_p[group_buffer_start:group_buffer_end])
self._fp32_backup_m[group_buffer_start:group_buffer_end].copy_(self._fp32_m[group_buffer_start:group_buffer_end])
self._fp32_backup_v[group_buffer_start:group_buffer_end].copy_(self._fp32_v[group_buffer_start:group_buffer_end])
fused_adam_cuda.adam( fused_adam_cuda.adam(
self._fp32_p[group_buffer_start:group_buffer_end], self._fp32_p[group_buffer_start:group_buffer_end],
self._new_params[group_shard_start:group_shard_end], self._new_params[group_shard_start:group_shard_end],
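For reference, the algebra an undo kernel can exploit: given the same gradient, each Adam update is invertible in closed form, which is why `revert_method == 1` needs no backup buffers. The sketch below is an illustration, not the actual CUDA kernel; it ignores bias correction, gradient scaling (`combined_scale`), and weight decay, all of which the real `adam`/`adam_undo` kernels take as arguments.

import torch

def adam_step(p, m, v, g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).add_(g * g, alpha=1 - beta2)
    p.sub_(lr * m / (v.sqrt() + eps))

def adam_undo(p, m, v, g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Reverse order: the parameter update used the *post-step* m and v.
    p.add_(lr * m / (v.sqrt() + eps))
    m.sub_(g * (1 - beta1)).div_(beta1)      # invert m = beta1*m_prev + (1-beta1)*g
    v.sub_(g * g * (1 - beta2)).div_(beta2)  # invert v = beta2*v_prev + (1-beta2)*g*g

p, m, v = torch.randn(8), torch.zeros(8), torch.zeros(8)
g = torch.randn(8)
p0 = p.clone()
adam_step(p, m, v, g)
adam_undo(p, m, v, g)
print((p - p0).norm())  # small but generally nonzero: rounding error

Because floating-point arithmetic rounds, the undo restores state only approximately; that presumably motivates the bit-exact double-buffer mode (2) and the debug mode (3), which is meant to run both and print the norm of the difference but is unimplemented in this commit.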
@@ -412,7 +435,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             for block in range(self._num_blocks):
                 grad_block = self._flat_grads[block*self._block_size:(block+1)*self._block_size]
                 grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-                shard_grad_norm = grad_shards[self._rank_in_group].float().norm()
+                shard_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)
                 partial_sum += (shard_grad_norm*shard_grad_norm)
             torch.distributed.all_reduce(partial_sum,group=self._rs_pg[0], async_op=False)
             self._L2_grad_norm = partial_sum.sqrt().item()
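The last hunk computes the gradient's global L2 norm from per-shard partial sums, switching from `.float().norm()` to `norm(dtype=torch.float32, p=2)` so the fp32 accumulation happens inside the norm call rather than through an explicit cast. A single-process sketch of the pattern, with the cross-rank `all_reduce` replaced by a local sum:

import torch

# Four "shards" standing in for one rank's slice each.
shards = [torch.randn(1024, dtype=torch.float16) for _ in range(4)]

# Per-shard squared norms, accumulated in fp32 as the new code does.
partial_sum = sum(s.norm(dtype=torch.float32, p=2) ** 2 for s in shards)

# After the (here omitted) all_reduce, the square root is the global norm.
global_norm = partial_sum.sqrt()
assert torch.allclose(global_norm,
                      torch.cat(shards).norm(dtype=torch.float32, p=2),
                      rtol=1e-3)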
...