Commit b85ff391 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Add option to revert step through double buffering

parent ffed6e80
......@@ -45,7 +45,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, dwu_num_blk_st=1):
dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
......@@ -64,6 +64,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._overflow_buf = torch.cuda.IntTensor([0])
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
......@@ -314,6 +322,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._fp32_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
if self._revert_method > 1:
self._fp32_backup_p = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_backup_m = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._fp32_backup_v = torch.zeros([self._num_blocks*self._shard_size]).float().cuda()
self._copy_to_fp32 = True
step = None
......@@ -376,6 +388,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
beta1, beta2 = group['betas']
if undo:
if self._revert_method == 1:
fused_adam_cuda.adam_undo(
self._fp32_p[group_buffer_start:group_buffer_end],
self._fp32_m[group_buffer_start:group_buffer_end],
......@@ -390,7 +403,17 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self.eps_mode,
bias_correction,
group['weight_decay'])
elif self._revert_method == 2:
self._fp32_p[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_p[group_buffer_start:group_buffer_end])
self._fp32_m[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_m[group_buffer_start:group_buffer_end])
self._fp32_v[group_buffer_start:group_buffer_end].copy_(self._fp32_backup_v[group_buffer_start:group_buffer_end])
elif self._revert_method == 3:
raise RuntimeError('revert_step debug option not implemented yet')
else:
if self._revert_method > 1:
self._fp32_backup_p[group_buffer_start:group_buffer_end].copy_(self._fp32_p[group_buffer_start:group_buffer_end])
self._fp32_backup_m[group_buffer_start:group_buffer_end].copy_(self._fp32_m[group_buffer_start:group_buffer_end])
self._fp32_backup_v[group_buffer_start:group_buffer_end].copy_(self._fp32_v[group_buffer_start:group_buffer_end])
fused_adam_cuda.adam(
self._fp32_p[group_buffer_start:group_buffer_end],
self._new_params[group_shard_start:group_shard_end],
......@@ -412,7 +435,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
for block in range(self._num_blocks):
grad_block = self._flat_grads[block*self._block_size:(block+1)*self._block_size]
grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
shard_grad_norm = grad_shards[self._rank_in_group].float().norm()
shard_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)
partial_sum += (shard_grad_norm*shard_grad_norm)
torch.distributed.all_reduce(partial_sum,group=self._rs_pg[0], async_op=False)
self._L2_grad_norm = partial_sum.sqrt().item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment