Commit 9d6d2e01 authored by Thor Johnsen

Make L2 grad norm a CPU variable

parent bc81b1c1
@@ -88,7 +88,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._do_not_flatten_model = do_not_flatten_model
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
-        self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
+        self._L2_grad_norm = None
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
         self._num_groups = self._world_size // self._group_size
@@ -333,9 +333,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     for chunk_id in range(self._num_chunks):
                         self._reductions_works[block_id][chunk_id].wait()
                 # Since the packed format is contiguous after reductions, only one norm is needed
-                self._L2_grad_norm = self._fp16_g.norm(dtype=torch.float32, p=2)**2
-                torch.distributed.all_reduce(self._L2_grad_norm, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm.sqrt_()
+                l2_grad_norm_sq = torch.empty([1], device='cuda')
+                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
+                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
+                self._L2_grad_norm = l2_grad_norm_sq.sqrt()

     def __launch_step_kernel(self, p, p_copy, m, v, g):
         combined_scale = self._global_scale
...
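For context, the second hunk uses a standard pattern for computing a global gradient L2 norm across ranks without gathering the gradients themselves: each rank takes the squared norm of its local packed gradient shard, the squared norms are summed with an all-reduce, and the square root of the sum is the global norm. Below is a minimal stand-alone sketch of that pattern; the function name distributed_l2_grad_norm and the flat_grads/group arguments are illustrative only and are not part of this commit.

# Minimal sketch of the norm-squared all-reduce pattern from the hunk above.
# Assumes torch.distributed is already initialized (e.g. with the NCCL backend)
# and that `flat_grads` holds this rank's packed gradient shard on the GPU.
import torch
import torch.distributed as dist

def distributed_l2_grad_norm(flat_grads, group=None):
    # Squared L2 norm of the local shard, accumulated in fp32 for accuracy.
    norm_sq = flat_grads.norm(dtype=torch.float32, p=2) ** 2
    # Summing the squared norms over all ranks (the default reduce op is SUM)
    # yields the squared norm of the full, concatenated gradient.
    dist.all_reduce(norm_sq, group=group)
    # Global L2 norm; still a device tensor, call .item() to read it on the host.
    return norm_sq.sqrt()

Because sqrt(sum_i ||g_i||^2) equals the norm of the concatenation of all shards, this gives the exact global norm while moving only a single scalar per rank over the network.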