Commit c246df42 authored by Nayan Singhal, committed by Facebook Github Bot

2/N bmuf

Summary:
Added BMUF implementation.

Todo:
1) Add unit test cases for model averaging and BMUF
2) Add warm-up before actually starting to train the model

Reviewed By: jay-mahadeokar

Differential Revision: D15871477

fbshipit-source-id: 866b0aba2d5bea5b65b4438acb49c886c4a87924
parent 1328bc09
fairseq/optim/bmuf.py
@@ -6,8 +6,11 @@
# can be found in the PATENTS file in the same directory.
import sys
import time
import torch
import torch.distributed as dist
from fairseq import optim
from . import FairseqOptimizer
@@ -29,6 +32,36 @@ class FairseqBMUF(FairseqOptimizer):
        self.params = params
        self._num_updates = 0
        self.sync_iter = self.args.global_sync_iter
        self.block_momentum = 1 - 1.0 / self.args.distributed_world_size
        self.block_lr = self.args.block_lr
        self._reset_local_data()
        self.warmup_iteration = self.args.warmup_iterations
        self.use_nbm = self.args.use_nbm

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        parser.add_argument(
            "--block-lr", default=1, type=float, help="block learning rate for bmuf"
        )
        parser.add_argument(
            "--global-sync-iter",
            default=10,
            type=int,
            help="Iteration for syncing global model",
        )
        parser.add_argument(
            "--warmup-iterations",
            default=500,
            type=int,
            help="warmup iterations for model to broadcast",
        )
        parser.add_argument(
            "--use-nbm",
            default=True,
            action="store_true",
            help="Specify whether you want to use classical BM / Nesterov BM",
        )

    @property
    def optimizer(self):
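For reference, a minimal sketch (not part of this commit) of the hyper-parameters the constructor above derives from these flags; the 8-GPU world size is an assumed example value, not something taken from the diff.

```python
# Hypothetical illustration of the values __init__ derives from the flags above.
distributed_world_size = 8     # assumed example value, not from the diff
block_lr = 1.0                 # --block-lr default
global_sync_iter = 10          # --global-sync-iter default
warmup_iterations = 500        # --warmup-iterations default

block_momentum = 1 - 1.0 / distributed_world_size
print(block_momentum)          # 0.875: a larger world size gives heavier block momentum
```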
@@ -58,18 +91,43 @@ class FairseqBMUF(FairseqOptimizer):
        """Clips gradient norm."""
        return self._optimizer.clip_grad_norm(max_norm)

    # removed in this commit: plain model averaging every sync_iter updates
    def _model_average_step(self):
        if self.get_num_updates() % self.sync_iter == 0:
            size = float(dist.get_world_size())
            for p in self.params:
                dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
                p.data /= size

    # added in its place: block synchronization with optional block momentum
    def _sync_block(self):
        if self.get_num_updates() % self.sync_iter == 0:
            if self.block_momentum != 0:
                self._BM_before_sync()
            self._allreduce_parameter()
            if self.block_momentum != 0:
                self._BM_after_sync()

    def _broadcast_model(self, rootRank=0):
        if (
            self.warmup_iteration != 0
            and self.get_num_updates() % self.warmup_iteration == 0
        ):
            self.warmup_iteration = 0
            # broadcast the local model
            for param in self.params:
                dist.broadcast(param.data, rootRank)
            # Also, broadcast the local BMUF state
            for param in (
                self.params_localprev
                + self.smoothed_grads_localprev
                + self.grads_localprev
            ):
                dist.broadcast(param, src=rootRank)

    def step(self, closure=None):
        """Performs a single optimization step."""
        self._optimizer.step(closure)
        self.set_num_updates(self.get_num_updates() + 1)
        # replaces the previous unconditional self._model_average_step() call
        if self.warmup_iteration != 0:
            self._broadcast_model()
        else:
            self._sync_block()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
@@ -82,3 +140,66 @@ class FairseqBMUF(FairseqOptimizer):
    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        self._num_updates = num_updates

    @torch.no_grad()
    def _reset_local_data(self):
        self.params_localprev = [torch.zeros_like(p.data) for p in self.params]
        self.smoothed_grads_localprev = [
            p.data.new_zeros(p.data.size()) for p in self.params
        ]
        self.grads_localprev = [p.data.new_zeros(p.data.size()) for p in self.params]
        # initialize the previous-parameter copies with the current parameters
        for param, copy_param in zip(self.params, self.params_localprev):
            copy_param.copy_(param.data)

    @torch.no_grad()
    def _BM_before_sync(self):
        # prev_param is the global copy from the previously finished
        # synchronization, while param.data holds the local parameters after
        # sync_iter local updates on this GPU, so grad is the difference
        # between the previously synced model and the current local model.
        for index, (param, prev_param) in enumerate(
            zip(self.params, self.params_localprev)
        ):
            self.grads_localprev[index] = prev_param - param.data

    def _allreduce_parameter(self):
        for index, param in enumerate(self.params):
            sync_para = (
                param.data if self.block_momentum == 0 else self.grads_localprev[index]
            )
            sync_para /= float(dist.get_world_size())
            dist.all_reduce(sync_para, op=dist.ReduceOp.SUM)

    @torch.no_grad()
    def _BM_after_sync(self):
        for index, (param, prev_param, smoothed_grad, grad) in enumerate(
            zip(
                self.params,
                self.params_localprev,
                self.smoothed_grads_localprev,
                # all machines share the same value of smoothed_grad, since it
                # is always computed from synchronized gradients.
                self.grads_localprev,
            )
        ):
            # prev_param is the last synchronized parameter. Although
            # smoothed_grad is stored locally, every process holds the same
            # value, so param ends up as a globally synchronized copy.
            # This is essentially a first-order infinite impulse response (IIR)
            # filter with gain (1 - BM) * BM_lr:
            #   smoothed_grad(t) = BM * smoothed_grad(t-1) + (1 - BM) * BM_lr * grad(t)
            smoothed_grad = (
                smoothed_grad * self.block_momentum
                + grad * (1 - self.block_momentum) * self.block_lr
            )
            param.data.copy_(prev_param - smoothed_grad)
            # Nesterov block momentum: do a partial weight update before
            # calculating the gradient
            if self.use_nbm:
                param.data.copy_(param.data - self.block_momentum * smoothed_grad)
            # back up for the next synchronization
            self.smoothed_grads_localprev[index] = smoothed_grad
            prev_param.copy_(param.data)
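In equation form, the update implemented by _BM_before_sync(), _allreduce_parameter() and _BM_after_sync() above can be summarized as follows (a sketch of the non-zero-block-momentum path only, with BM the block momentum, BM_lr the block learning rate, and N the world size; by default BM = 1 - 1/N from the constructor above):

```latex
% g_t: average over workers of (previous global params - current local params)
\begin{align*}
  g_t &= \frac{1}{N} \sum_{i=1}^{N} \left(\theta_{t-1}^{\text{global}} - \theta_t^{(i)}\right) \\
  \Delta_t &= \mathrm{BM}\,\Delta_{t-1} + (1 - \mathrm{BM})\,\mathrm{BM_{lr}}\, g_t \\
  \theta_t^{\text{global}} &= \theta_{t-1}^{\text{global}} - \Delta_t \\
  \theta_t^{\text{NBM}} &= \theta_t^{\text{global}} - \mathrm{BM}\,\Delta_t
  \qquad \text{(Nesterov variant, applied when use\_nbm is set)}
\end{align*}
```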
fairseq/options.py
@@ -99,6 +99,10 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
    if hasattr(args, 'task'):
        from fairseq.tasks import TASK_REGISTRY
        TASK_REGISTRY[args.task].add_args(parser)

    if getattr(args, 'use_bmuf', False):
        # hack to support extra args for block distributed data parallelism
        from fairseq.optim.bmuf import FairseqBMUF
        FairseqBMUF.add_args(parser)

    # Parse a second time.
    if parse_known:
@@ -322,8 +326,6 @@ def add_optimization_args(parser):
                       help='stop training when the learning rate reaches this minimum')
    group.add_argument('--use-bmuf', default=False, action='store_true',
                       help="specify global optimizer for syncing models on different GPUs/Shards")
    # removed in this commit (the flag is now added via FairseqBMUF.add_args above):
    group.add_argument('--global-sync-iter', default=10, type=int,
                       help='Iteration for syncing global model')
    # fmt: on
    return group
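A self-contained sketch (plain argparse only, no fairseq import; the flag names mirror FairseqBMUF.add_args and --use-bmuf above) of the two-pass parsing pattern that parse_args_and_arch uses to expose the BMUF options only when --use-bmuf is given:

```python
import argparse

def add_bmuf_args(parser):
    # mirrors FairseqBMUF.add_args() from the diff above
    parser.add_argument("--block-lr", default=1, type=float)
    parser.add_argument("--global-sync-iter", default=10, type=int)
    parser.add_argument("--warmup-iterations", default=500, type=int)
    parser.add_argument("--use-nbm", default=True, action="store_true")

parser = argparse.ArgumentParser()
parser.add_argument("--use-bmuf", default=False, action="store_true")

cli = ["--use-bmuf", "--global-sync-iter", "50"]

# First pass: discover whether BMUF is requested, ignoring unknown flags.
args, _ = parser.parse_known_args(cli)
if args.use_bmuf:
    # as in the diff: attach the extra BMUF flags, then parse a second time
    add_bmuf_args(parser)
args = parser.parse_args(cli)
print(args.use_bmuf, args.global_sync_iter, args.block_lr)  # -> True 50 1
```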