"vscode:/vscode.git/clone" did not exist on "c4211f665ca78e6ee526edc14ee0b3547f1c94a3"
Commit 1fd8943e authored by Nayan Singhal's avatar Nayan Singhal Committed by Facebook Github Bot
Browse files

Average local optimizer param after warmup and during bmuf sync

Summary: We have observed that averaging the local optimizer parameters after warmup, instead of resetting or broadcasting them, improves the WER.

Reviewed By: skritika

Differential Revision: D16739278

fbshipit-source-id: 75033d2d25f9a88fd6dd325d0d9d4c856d22d947
parent 3e3fe722
...@@ -8,6 +8,7 @@ import types ...@@ -8,6 +8,7 @@ import types
import torch import torch
import torch.optim import torch.optim
import torch.distributed as dist
from . import FairseqOptimizer, register_optimizer from . import FairseqOptimizer, register_optimizer
...@@ -53,6 +54,17 @@ class FairseqAdam(FairseqOptimizer): ...@@ -53,6 +54,17 @@ class FairseqAdam(FairseqOptimizer):
'weight_decay': self.args.weight_decay, 'weight_decay': self.args.weight_decay,
} }
def average_params(self):
    """Average the Adam moment estimates across all workers.

    Only used during BMUF distributed training: each worker scales its
    local first/second moment tensors down by the world size and then
    all-reduces (sums) them, so every worker ends up holding the global
    average of the optimizer state.
    """
    world_size = float(dist.get_world_size())
    # NOTE(review): assumes the state_dict tensors alias the live
    # optimizer state, so the in-place ops update the optimizer itself.
    for param_state in self.optimizer.state_dict()["state"].values():
        for key in ("exp_avg", "exp_avg_sq"):
            param_state[key] /= world_size
            dist.all_reduce(param_state[key], op=dist.ReduceOp.SUM)
class Adam(torch.optim.Optimizer): class Adam(torch.optim.Optimizer):
"""Implements Adam algorithm. """Implements Adam algorithm.
......
...@@ -31,6 +31,7 @@ class FairseqBMUF(FairseqOptimizer): ...@@ -31,6 +31,7 @@ class FairseqBMUF(FairseqOptimizer):
self.warmup_iteration = self.args.warmup_iterations self.warmup_iteration = self.args.warmup_iterations
self.use_nbm = self.args.use_nbm self.use_nbm = self.args.use_nbm
self.initial_state = self._optimizer.state_dict() self.initial_state = self._optimizer.state_dict()
self.average_sync = self.args.average_sync
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
...@@ -62,6 +63,12 @@ class FairseqBMUF(FairseqOptimizer): ...@@ -62,6 +63,12 @@ class FairseqBMUF(FairseqOptimizer):
action="store_true", action="store_true",
help="Specify whether you want to use classical BM / Nesterov BM", help="Specify whether you want to use classical BM / Nesterov BM",
) )
parser.add_argument(
    "--average-sync",
    default=True,
    action="store_true",
    help="Specify whether you want to average the local momentum after each sync",
)
# BUG-FIX: with default=True, the store_true flag above could never be
# turned off from the command line. Provide an explicit off switch that
# writes into the same `average_sync` destination (backward compatible:
# passing --average-sync still works and the default is unchanged).
parser.add_argument(
    "--no-average-sync",
    dest="average_sync",
    action="store_false",
    help="Disable averaging of the local optimizer state after each sync",
)
@property @property
def optimizer(self): def optimizer(self):
...@@ -91,34 +98,50 @@ class FairseqBMUF(FairseqOptimizer): ...@@ -91,34 +98,50 @@ class FairseqBMUF(FairseqOptimizer):
"""Clips gradient norm.""" """Clips gradient norm."""
return self._optimizer.clip_grad_norm(max_norm) return self._optimizer.clip_grad_norm(max_norm)
def average_params(self):
    """Delegate cross-worker optimizer-state averaging to the wrapped optimizer."""
    self._optimizer.average_params()
def _block_sync(self): def _block_sync(self):
# Update the global model using local models from all GPUs. # Update the global model using local models from all GPUs
if self._is_bmuf_iter(): # (Step-1) Calculate grad between previously synced model and
# currrent local model
if self.block_momentum != 0: if self.block_momentum != 0:
self._BM_before_sync() self._calc_grad()
self._allreduce_parameter() # (Step-2) Average gradient from all GPUs
self._avg_grad_from_all_gpus()
# (Step-3) Calculate global momentum and update the global model
if self.block_momentum != 0: if self.block_momentum != 0:
self._BM_after_sync() self._update_global_model()
# (Step-4) Average local optimizer params
if self.average_sync:
self.average_params()
def _is_warmup_end(self): def _is_warmup_end(self):
# Check whether train iterations is equal to warmup iter
if self.get_num_updates() == self.warmup_iteration: if self.get_num_updates() == self.warmup_iteration:
return True return True
return False return False
def _is_bmuf_iter(self): def _is_bmuf_iter(self):
# Check whether train iterations is equal to bmuf sync iter
if self.get_num_updates() % self.sync_iter == 0: if self.get_num_updates() % self.sync_iter == 0:
return True return True
return False return False
def _warmup_sync(self, rootRank=0): def _warmup_sync(self, root_rank=0):
# broadcast the local model to all GPUs # Broadcast the local model to all gpus
for param in self.params: for param in self.params:
dist.broadcast(param.data, src=rootRank) dist.broadcast(param.data, src=root_rank)
# Reset the local optimizer state and local bmuf related param # Update local optimizer state
if self.average_sync:
self._optimizer.average_params()
else:
self._optimizer.load_state_dict(self.initial_state) self._optimizer.load_state_dict(self.initial_state)
self._reset_local_data() self._reset_local_data()
def step(self, closure=None): def step(self, closure=None):
...@@ -127,7 +150,7 @@ class FairseqBMUF(FairseqOptimizer): ...@@ -127,7 +150,7 @@ class FairseqBMUF(FairseqOptimizer):
self.set_num_updates(self.get_num_updates() + 1) self.set_num_updates(self.get_num_updates() + 1)
if self._is_warmup_end(): if self._is_warmup_end():
self._warmup_sync() self._warmup_sync()
else: elif self._is_bmuf_iter():
self._block_sync() self._block_sync()
def zero_grad(self): def zero_grad(self):
...@@ -144,61 +167,56 @@ class FairseqBMUF(FairseqOptimizer): ...@@ -144,61 +167,56 @@ class FairseqBMUF(FairseqOptimizer):
@torch.no_grad() @torch.no_grad()
def _reset_local_data(self): def _reset_local_data(self):
"""Resetting all the BMUF specific params.""" # (Step-0) Initialize global momentum parameters and store global copy on each gpu
self.params_localprev = [torch.zeros_like(p.data) for p in self.params] self.global_params = [torch.zeros_like(p.data) for p in self.params]
self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params]
self.smoothed_grads_localprev = [ self.grads = [p.data.new_zeros(p.data.size()) for p in self.params]
p.data.new_zeros(p.data.size()) for p in self.params
]
self.grads_localprev = [p.data.new_zeros(p.data.size()) for p in self.params]
# saving the global model locally for calculating gradient during bmuf sync # saving the global model locally for calculating gradient during bmuf sync
for param, copy_param in zip(self.params, self.params_localprev): for param, global_param in zip(self.params, self.global_params):
copy_param.copy_(param.data) global_param.copy_(param.data)
@torch.no_grad() @torch.no_grad()
def _BM_before_sync(self): def _calc_grad(self):
"""Calculate grad between previously synced model and currrent local model.""" # global_params is basically the global copy from the previously finished
# prev_param is basically the global copy from the previously finished
# synchronisation. param.data is local parameter after block_sync_freq # synchronisation. param.data is local parameter after block_sync_freq
# for the local gpu. so grad is difference between previously synced # for the local gpu. so grad is difference between previously synced
# model and currrent local model. # model and currrent local model.
for index, (param, prev_param) in enumerate( for index, (param, global_param) in enumerate(
zip(self.params, self.params_localprev) zip(self.params, self.global_params)
): ):
self.grads_localprev[index] = prev_param - param.data self.grads[index] = global_param - param.data
def _allreduce_parameter(self): def _avg_grad_from_all_gpus(self):
"""Average gradient from all the GPUs. """
for index, param in enumerate(self.params): for index, param in enumerate(self.params):
sync_para = ( sync_para = param.data if self.block_momentum == 0 else self.grads[index]
param.data if self.block_momentum == 0 else self.grads_localprev[index]
)
sync_para /= float(dist.get_world_size()) sync_para /= float(dist.get_world_size())
dist.all_reduce(sync_para, op=dist.ReduceOp.SUM) dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)
@torch.no_grad() @torch.no_grad()
def _BM_after_sync(self): def _update_global_model(self):
for index, (param, prev_param, smoothed_grad, grad) in enumerate( for index, (param, global_param, smoothed_grad, grad) in enumerate(
zip( zip(
self.params, self.params,
self.params_localprev, self.global_params,
self.smoothed_grads_localprev, self.smoothed_grads,
# all machines would share the same value of smoothed_grad, since it is # all gpus would share the same value of smoothed_grad, since it is
# always computed on synchronized gradients. # always computed on synchronized gradients.
self.grads_localprev, self.grads,
) )
): ):
# prev_param is basically last syncrhornized parameter. though # global_param is basically last syncrhornized parameter. though
# smoothed_grad is local, all processes will have same value of # smoothed_grad is local, all processes will have same value of
# smoothed_grad and hence param is globally synchronized copy. # smoothed_grad and hence param is globally synchronized copy.
# smoothed_grad(t)=BM * smoothed_grad(t-1) + BM_lr*grad(t) # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
smoothed_grad = smoothed_grad * self.block_momentum + grad * self.block_lr smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad
param.data.copy_(prev_param - smoothed_grad) param.data.copy_(global_param - smoothed_grad)
# A Nesterov momentum here is to do a partial weight update before # A Nesterov momentum here is to do a partial weight update before
# calculating the gradient # calculating the gradient
if self.use_nbm: if self.use_nbm:
param.data.copy_(param.data - self.block_momentum * smoothed_grad) param.data.copy_(param.data - self.block_momentum * smoothed_grad)
# backup for the next synchronization. # backup for the next synchronization.
self.smoothed_grads_localprev[index] = smoothed_grad self.smoothed_grads[index] = smoothed_grad
prev_param.copy_(param.data) global_param.copy_(param.data)
...@@ -108,3 +108,6 @@ class FairseqOptimizer(object): ...@@ -108,3 +108,6 @@ class FairseqOptimizer(object):
if hasattr(self.optimizer, 'supports_memory_efficient_fp16'): if hasattr(self.optimizer, 'supports_memory_efficient_fp16'):
return self.optimizer.supports_memory_efficient_fp16 return self.optimizer.supports_memory_efficient_fp16
return False return False
def average_params(self):
    """No-op hook; optimizers supporting BMUF state averaging override this."""
    pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment