Commit 1fd8943e authored by Nayan Singhal, committed by Facebook Github Bot

Average local optimizer params after warmup and during BMUF sync

Summary: We have seen that averaging the local optimizer parameters across GPUs after warmup, instead of resetting them or broadcasting from a single rank, improves the WER (see the sketch below).

Reviewed By: skritika

Differential Revision: D16739278

fbshipit-source-id: 75033d2d25f9a88fd6dd325d0d9d4c856d22d947
parent 3e3fe722
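The behavior the summary describes, in isolation: after warmup (and at every BMUF sync), each GPU's Adam moments are averaged across workers rather than reloaded from the pre-warmup optimizer state. A minimal standalone sketch of that averaging, assuming `torch.distributed` is already initialized; the helper name and the direct use of `optimizer.state` are illustrative, not fairseq's API:

import torch
import torch.optim
import torch.distributed as dist

def average_adam_state_(optimizer: torch.optim.Adam) -> None:
    """Average each parameter's Adam moments across all workers, in place.

    Dividing locally by the world size and then summing with all_reduce
    is equivalent to taking the mean -- the same pattern the new
    FairseqAdam.average_params() in the diff below uses.
    """
    world_size = float(dist.get_world_size())
    for state in optimizer.state.values():
        for key in ("exp_avg", "exp_avg_sq"):
            state[key] /= world_size                           # pre-scale locally
            dist.all_reduce(state[key], op=dist.ReduceOp.SUM)  # sum of scaled copies = mean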
@@ -8,6 +8,7 @@ import types

 import torch
 import torch.optim
+import torch.distributed as dist

 from . import FairseqOptimizer, register_optimizer

@@ -53,6 +54,17 @@ class FairseqAdam(FairseqOptimizer):
             'weight_decay': self.args.weight_decay,
         }

+    def average_params(self):
+        """Reduce Params is only used during BMUF distributed training."""
+        state_dict = self.optimizer.state_dict()
+        total_gpus = float(dist.get_world_size())
+
+        for _, value in state_dict["state"].items():
+            value["exp_avg"] /= total_gpus
+            value["exp_avg_sq"] /= total_gpus
+            dist.all_reduce(value["exp_avg"], op=dist.ReduceOp.SUM)
+            dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM)
+

 class Adam(torch.optim.Optimizer):
     """Implements Adam algorithm.
...
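One detail worth noting about `average_params` above: it mutates the tensors returned by `optimizer.state_dict()["state"]` in place. In the PyTorch versions this code targets, `Optimizer.state_dict()` does not clone the state tensors, so the in-place division and `all_reduce` update the live `exp_avg` / `exp_avg_sq` buffers directly. A small standalone check of that assumption (single process, no distributed calls):

import torch

# Hypothetical sanity check, not part of the commit: verify that
# Optimizer.state_dict() hands back references to the live state tensors.
p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([p], lr=1e-3)
p.grad = torch.randn(4)
opt.step()  # populate exp_avg / exp_avg_sq

sd = opt.state_dict()
print(sd["state"][0]["exp_avg"] is opt.state[p]["exp_avg"])  # expected: True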
@@ -31,6 +31,7 @@ class FairseqBMUF(FairseqOptimizer):
         self.warmup_iteration = self.args.warmup_iterations
         self.use_nbm = self.args.use_nbm
         self.initial_state = self._optimizer.state_dict()
+        self.average_sync = self.args.average_sync

     @staticmethod
     def add_args(parser):
@@ -62,6 +63,12 @@ class FairseqBMUF(FairseqOptimizer):
             action="store_true",
             help="Specify whether you want to use classical BM / Nesterov BM",
         )
+        parser.add_argument(
+            "--average-sync",
+            default=True,
+            action="store_true",
+            help="Specify whether you want to average the local momentum after each sync",
+        )

     @property
     def optimizer(self):
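Note on the new flag as registered here: with `action="store_true"` and `default=True`, parsing can only ever yield `True`, so `--average-sync` is effectively always on unless the default itself is changed. A standalone `argparse` sketch of that behavior (not fairseq code):

import argparse

# A store_true flag whose default is already True cannot be switched off
# from the command line; both invocations below print True.
parser = argparse.ArgumentParser()
parser.add_argument("--average-sync", default=True, action="store_true")

print(parser.parse_args([]).average_sync)                  # True (default)
print(parser.parse_args(["--average-sync"]).average_sync)  # True (flag is a no-op)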
@@ -91,34 +98,50 @@ class FairseqBMUF(FairseqOptimizer):
         """Clips gradient norm."""
         return self._optimizer.clip_grad_norm(max_norm)

+    def average_params(self):
+        self._optimizer.average_params()
+
     def _block_sync(self):
-        # Update the global model using local models from all GPUs.
-        if self._is_bmuf_iter():
-            if self.block_momentum != 0:
-                self._BM_before_sync()
-
-            self._allreduce_parameter()
-
-            if self.block_momentum != 0:
-                self._BM_after_sync()
+        # Update the global model using local models from all GPUs
+        # (Step-1) Calculate grad between previously synced model and
+        # current local model
+        if self.block_momentum != 0:
+            self._calc_grad()
+
+        # (Step-2) Average gradient from all GPUs
+        self._avg_grad_from_all_gpus()
+
+        # (Step-3) Calculate global momentum and update the global model
+        if self.block_momentum != 0:
+            self._update_global_model()
+
+        # (Step-4) Average local optimizer params
+        if self.average_sync:
+            self.average_params()

     def _is_warmup_end(self):
+        # Check whether the number of train iterations has reached the warmup iterations
         if self.get_num_updates() == self.warmup_iteration:
             return True
         return False

     def _is_bmuf_iter(self):
+        # Check whether the current train iteration is a BMUF sync iteration
         if self.get_num_updates() % self.sync_iter == 0:
             return True
         return False

-    def _warmup_sync(self, rootRank=0):
-        # broadcast the local model to all GPUs
+    def _warmup_sync(self, root_rank=0):
+        # Broadcast the local model to all GPUs
         for param in self.params:
-            dist.broadcast(param.data, src=rootRank)
+            dist.broadcast(param.data, src=root_rank)
+
+        # Update local optimizer state
+        if self.average_sync:
+            self._optimizer.average_params()
+        else:
+            self._optimizer.load_state_dict(self.initial_state)

-        # Reset the local optimizer state and local bmuf related param
-        self._optimizer.load_state_dict(self.initial_state)
         self._reset_local_data()

     def step(self, closure=None):
@@ -127,7 +150,7 @@ class FairseqBMUF(FairseqOptimizer):
         self.set_num_updates(self.get_num_updates() + 1)
         if self._is_warmup_end():
             self._warmup_sync()
-        else:
+        elif self._is_bmuf_iter():
             self._block_sync()

     def zero_grad(self):
@@ -144,61 +167,56 @@ class FairseqBMUF(FairseqOptimizer):
     @torch.no_grad()
     def _reset_local_data(self):
-        """Resetting all the BMUF specific params."""
-        self.params_localprev = [torch.zeros_like(p.data) for p in self.params]
-
-        self.smoothed_grads_localprev = [
-            p.data.new_zeros(p.data.size()) for p in self.params
-        ]
-        self.grads_localprev = [p.data.new_zeros(p.data.size()) for p in self.params]
+        # (Step-0) Initialize global momentum parameters and store a global copy on each GPU
+        self.global_params = [torch.zeros_like(p.data) for p in self.params]
+        self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
+        self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]

         # saving the global model locally for calculating gradient during bmuf sync
-        for param, copy_param in zip(self.params, self.params_localprev):
-            copy_param.copy_(param.data)
+        for param, global_param in zip(self.params, self.global_params):
+            global_param.copy_(param.data)

     @torch.no_grad()
-    def _BM_before_sync(self):
-        """Calculate grad between previously synced model and currrent local model."""
-        # prev_param is basically the global copy from the previously finished
+    def _calc_grad(self):
+        # global_params is basically the global copy from the previously finished
         # synchronisation. param.data is local parameter after block_sync_freq
         # for the local gpu. so grad is difference between previously synced
         # model and currrent local model.
-        for index, (param, prev_param) in enumerate(
-            zip(self.params, self.params_localprev)
+        for index, (param, global_param) in enumerate(
+            zip(self.params, self.global_params)
         ):
-            self.grads_localprev[index] = prev_param - param.data
+            self.grads[index] = global_param - param.data

-    def _allreduce_parameter(self):
-        """Average gradient from all the GPUs."""
+    def _avg_grad_from_all_gpus(self):
         for index, param in enumerate(self.params):
-            sync_para = (
-                param.data if self.block_momentum == 0 else self.grads_localprev[index]
-            )
+            sync_para = param.data if self.block_momentum == 0 else self.grads[index]
             sync_para /= float(dist.get_world_size())
             dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)

     @torch.no_grad()
-    def _BM_after_sync(self):
-        for index, (param, prev_param, smoothed_grad, grad) in enumerate(
+    def _update_global_model(self):
+        for index, (param, global_param, smoothed_grad, grad) in enumerate(
             zip(
                 self.params,
-                self.params_localprev,
-                self.smoothed_grads_localprev,
-                # all machines would share the same value of smoothed_grad, since it is
+                self.global_params,
+                self.smoothed_grads,
+                # all GPUs would share the same value of smoothed_grad, since it is
                 # always computed on synchronized gradients.
-                self.grads_localprev,
+                self.grads,
             )
         ):
-            # prev_param is basically last syncrhornized parameter. though
+            # global_param is basically the last synchronized parameter. though
             # smoothed_grad is local, all processes will have same value of
             # smoothed_grad and hence param is globally synchronized copy.
-            # smoothed_grad(t)=BM * smoothed_grad(t-1) + BM_lr*grad(t)
-            smoothed_grad = smoothed_grad * self.block_momentum + grad * self.block_lr
-            param.data.copy_(prev_param - smoothed_grad)
+            # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
+            smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad
+            param.data.copy_(global_param - smoothed_grad)

             # A Nesterov momentum here is to do a partial weight update before
             # calculating the gradient
             if self.use_nbm:
                 param.data.copy_(param.data - self.block_momentum * smoothed_grad)

             # backup for the next synchronization.
-            self.smoothed_grads_localprev[index] = smoothed_grad
-            prev_param.copy_(param.data)
+            self.smoothed_grads[index] = smoothed_grad
+            global_param.copy_(param.data)
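Written out in the notation of the comments above (BM = `block_momentum`, BM_lr = `block_lr`, N = world size; the math symbols are just notation, not fairseq identifiers), one sync of `_calc_grad`, `_avg_grad_from_all_gpus`, and `_update_global_model` computes:

g_t = \frac{1}{N}\sum_{i=1}^{N}\left(\theta^{\text{global}}_{t-1} - \theta^{(i)}_t\right)

\Delta_t = \mathrm{BM}\,\Delta_{t-1} + \mathrm{BM\_lr}\;g_t

\theta_t \leftarrow \theta^{\text{global}}_{t-1} - \Delta_t

\theta_t \leftarrow \theta_t - \mathrm{BM}\,\Delta_t \quad\text{(only with use\_nbm)}

Here \Delta_t is `smoothed_grad` and \theta^{(i)}_t are GPU i's parameters after `block_sync_freq` local steps; the resulting \theta_t, including the optional Nesterov shift, is copied back into `global_params` for the next sync. When `block_momentum` is 0, steps 1 and 3 are skipped and the parameters themselves are all-reduce averaged instead.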
@@ -108,3 +108,6 @@ class FairseqOptimizer(object):
         if hasattr(self.optimizer, 'supports_memory_efficient_fp16'):
             return self.optimizer.supports_memory_efficient_fp16
         return False
+
+    def average_params(self):
+        pass
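The stub above is what ties the other hunks together: `FairseqBMUF.average_params()` delegates to whatever optimizer it wraps, `FairseqAdam` overrides the hook with the all-reduce averaging from the first hunk, and any optimizer without an override falls through to this no-op. A compressed sketch of that call chain (class and method names come from the diff; the constructor argument is illustrative only):

# Sketch of the hook pattern introduced in this commit, not the full classes.
class FairseqOptimizer:
    def average_params(self):
        # Default: optimizers without cross-GPU averaging silently do nothing.
        pass


class FairseqAdam(FairseqOptimizer):
    def average_params(self):
        # Overridden: all-reduce the Adam moments (see the FairseqAdam hunk above).
        ...


class FairseqBMUF(FairseqOptimizer):
    def __init__(self, wrapped_optimizer):
        self._optimizer = wrapped_optimizer

    def average_params(self):
        # Delegate to the wrapped optimizer during warmup sync and block sync.
        self._optimizer.average_params()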