Commit ed27ed8b authored by Nayan Singhal's avatar Nayan Singhal Committed by Facebook Github Bot

BMUF Resetting local state param

Summary:
BMUF
1) Resetting BMUF parameters after warmup.
2) Resetting local param state after warmup.
3) Allowing the user to pass a block momentum value instead of the GPU-derived block momentum.

Reviewed By: skritika, mrshenli

Differential Revision: D16692026

fbshipit-source-id: d02eaf29d0e4b37007418166ec937d4bf5fe6aca
parent a8e32111
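
For context on change (3) in the summary, a minimal sketch of how the effective block momentum changes; the world size of 16 is an illustrative assumption, not taken from this diff:

# Before this change, block momentum was derived from the GPU count; after it,
# the value comes from the new --block-momentum flag (default 0.875 in this diff).
distributed_world_size = 16  # illustrative assumption
old_block_momentum = 1 - 1.0 / distributed_world_size  # 0.9375, fixed by GPU count
new_block_momentum = 0.875                              # user-chosen via --block-momentum
print(old_block_momentum, new_block_momentum)
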
@@ -26,11 +26,12 @@ class FairseqBMUF(FairseqOptimizer):
self.params = params
self._num_updates = 0
self.sync_iter = self.args.global_sync_iter
self.block_momentum = 1 - 1.0 / self.args.distributed_world_size
self.block_momentum = self.args.block_momentum
self.block_lr = self.args.block_lr
self._reset_local_data()
self.warmup_iteration = self.args.warmup_iterations
self.use_nbm = self.args.use_nbm
self.initial_state = self._optimizer.state_dict()
@staticmethod
def add_args(parser):
@@ -38,6 +39,12 @@ class FairseqBMUF(FairseqOptimizer):
parser.add_argument(
"--block-lr", default=1, type=float, help="block learning rate for bmuf"
)
parser.add_argument(
"--block-momentum",
default=0.875,
type=float,
help="block momentum for bmuf",
)
parser.add_argument(
"--global-sync-iter",
default=10,
@@ -85,8 +92,9 @@ class FairseqBMUF(FairseqOptimizer):
"""Clips gradient norm."""
return self._optimizer.clip_grad_norm(max_norm)
def _sync_block(self):
if self.get_num_updates() % self.sync_iter == 0:
def _block_sync(self):
# Update the global model using local models from all GPUs.
if self._is_bmuf_iter():
if self.block_momentum != 0:
self._BM_before_sync()
@@ -95,33 +103,33 @@ class FairseqBMUF(FairseqOptimizer):
if self.block_momentum != 0:
self._BM_after_sync()
def _broadcast_model(self, rootRank=0):
if (
self.warmup_iteration != 0
and self.get_num_updates() % self.warmup_iteration == 0
):
self.warmup_iteration = 0
def _is_warmup_end(self):
if self.get_num_updates() == self.warmup_iteration:
return True
return False
# broadcast the local model
for param in self.params:
dist.broadcast(param.data, rootRank)
def _is_bmuf_iter(self):
if self.get_num_updates() % self.sync_iter == 0:
return True
return False
def _warmup_sync(self, rootRank=0):
# broadcast the local model to all GPUs
for param in self.params:
dist.broadcast(param.data, src=rootRank)
# Also, broadcast the local parameters
for param in (
self.params_localprev
+ self.smoothed_grads_localprev
+ self.grads_localprev
):
dist.broadcast(param, src=rootRank)
# Reset the local optimizer state and the local BMUF-related params
self._optimizer.load_state_dict(self.initial_state)
self._reset_local_data()
def step(self, closure=None):
"""Performs a single optimization step."""
self._optimizer.step(closure)
self.set_num_updates(self.get_num_updates() + 1)
if self.warmup_iteration != 0:
self._broadcast_model()
if self._is_warmup_end():
self._warmup_sync()
else:
self._sync_block()
self._block_sync()
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
@@ -137,6 +145,7 @@ class FairseqBMUF(FairseqOptimizer):
@torch.no_grad()
def _reset_local_data(self):
"""Resetting all the BMUF specific params."""
self.params_localprev = [torch.zeros_like(p.data) for p in self.params]
self.smoothed_grads_localprev = [
@@ -144,12 +153,13 @@ class FairseqBMUF(FairseqOptimizer):
]
self.grads_localprev = [p.data.new_zeros(p.data.size()) for p in self.params]
# initialize
# Save the global model locally for calculating the gradient during BMUF sync
for param, copy_param in zip(self.params, self.params_localprev):
copy_param.copy_(param.data)
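
A sketch of the per-GPU state that _reset_local_data rebuilds (it now runs both at construction and again from _warmup_sync); the initializer for smoothed_grads_localprev is cut off by the hunk above, so zero-initialization is an assumption:

import torch

def reset_local_data(params):
    """Sketch of the buffers _reset_local_data re-creates for each parameter."""
    params_localprev = [torch.zeros_like(p.data) for p in params]          # last synced global model
    smoothed_grads_localprev = [torch.zeros_like(p.data) for p in params]  # momentum buffer (assumed zeros)
    grads_localprev = [p.data.new_zeros(p.data.size()) for p in params]    # block gradients
    # Save the current (global) model so the next block gradient is computed against it.
    for param, copy_param in zip(params, params_localprev):
        copy_param.copy_(param.data)
    return params_localprev, smoothed_grads_localprev, grads_localprev
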
@torch.no_grad()
def _BM_before_sync(self):
"""Calculate grad between previously synced model and currrent local model."""
# prev_param is the global copy from the previously finished
# synchronization. param.data is the local parameter after block_sync_freq
# updates on the local GPU, so grad is the difference between the previously synced
@@ -160,6 +170,7 @@ class FairseqBMUF(FairseqOptimizer):
self.grads_localprev[index] = prev_param - param.data
def _allreduce_parameter(self):
"""Average gradient from all the GPUs. """
for index, param in enumerate(self.params):
sync_para = (
param.data if self.block_momentum == 0 else self.grads_localprev[index]
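
The hunk shows only which tensor is selected for reduction: the raw parameters when block momentum is zero, otherwise the block gradients from _BM_before_sync. The reduction itself falls outside the shown lines; averaging as sum / world_size is an assumption in this sketch:

import torch.distributed as dist

def allreduce_parameter(params, grads_localprev, block_momentum, world_size):
    """Sketch of _allreduce_parameter under the stated assumptions."""
    for index, param in enumerate(params):
        # With block momentum disabled, average the parameters directly;
        # otherwise average the block gradients.
        sync_para = param.data if block_momentum == 0 else grads_localprev[index]
        sync_para /= world_size                            # assumption: pre-divide, then sum
        dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)   # in-place sum across GPUs
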
@@ -182,13 +193,8 @@ class FairseqBMUF(FairseqOptimizer):
# prev_param is basically the last synchronized parameter. Though
# smoothed_grad is local, all processes will have the same value of
# smoothed_grad, and hence param is a globally synchronized copy.
# This is essentially a first-order infinite impulse response (IIR)
# filter with the gain (1 - BM)*BM_lr:
# smoothed_grad(t)=BM * smoothed_grad(t-1) + (1 - BM)*BM_lr*grad(t)
smoothed_grad = (
smoothed_grad * self.block_momentum
+ grad * (1 - self.block_momentum) * self.block_lr
)
# smoothed_grad(t)=BM * smoothed_grad(t-1) + BM_lr*grad(t)
smoothed_grad = smoothed_grad * self.block_momentum + grad * self.block_lr
param.data.copy_(prev_param - smoothed_grad)
# Nesterov momentum here does a partial weight update before
# calculating the gradient
......
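
To restate the updated block step from the last hunk: the old comment described a first-order IIR filter with gain (1 - BM) * BM_lr, while the new code drops the (1 - BM) factor. A self-contained sketch of the new update, with names local to this example:

import torch

def bmuf_block_update(prev_param, smoothed_grad, grad, block_momentum=0.875, block_lr=1.0):
    """One BMUF block step following the new formula in this diff (sketch)."""
    # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t)
    smoothed_grad = smoothed_grad * block_momentum + grad * block_lr
    # The parameter is reset to the previously synced copy minus the smoothed gradient.
    param = prev_param - smoothed_grad
    return param, smoothed_grad

# Example with illustrative tensors:
p_prev = torch.ones(3)
p_new, s_new = bmuf_block_update(p_prev, torch.zeros(3), torch.full((3,), 0.1))
# p_new == tensor([0.9, 0.9, 0.9]), s_new == tensor([0.1, 0.1, 0.1])
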