Commit 6982c404 authored by Nayan Singhal, committed by Facebook GitHub Bot

Add Model Averaging

Summary:
Implemented model averaging for fairseq.
Removed the DDP wrapper when a global optimizer is provided.
Syncing all the models based on the iteration interval provided in the input.

TODO:
1) Fix the throughput and wps meters. Other meters need checking too.
2) Replace the model averaging code with the BMUF algorithm implementation.

Reviewed By: myleott

Differential Revision: D15711044

fbshipit-source-id: 58a4af74db2a61d06762597b95836cbeb1ed82cc
parent 78c2fcf0
@@ -11,6 +11,7 @@ import os
 from fairseq import registry
 from fairseq.optim.fairseq_optimizer import FairseqOptimizer
 from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
+from fairseq.optim.bmuf import FairseqBMUF
 __all__ = [
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.distributed as dist
from fairseq import optim
from . import FairseqOptimizer
class FairseqBMUF(FairseqOptimizer):
"""
Implements incremental block distributed data parallelism similar to
https://ieeexplore.ieee.org/document/7472805
Paper title: Scalable training of deep learning machines by incremental
block training with intra-block parallel optimization and blockwise
model-update filtering
"""
def __init__(self, args, params, optimizer):
super().__init__(args, params)
self._optimizer = optimizer
self.params = params
self._num_updates = 0
self.sync_iter = self.args.global_sync_iter
@property
def optimizer(self):
return self._optimizer.optimizer
@property
def optimizer_config(self):
return self._optimizer.optimizer_config
def get_lr(self):
return self._optimizer.get_lr()
def set_lr(self, lr):
self._optimizer.set_lr(lr)
def state_dict(self):
return self._optimizer.state_dict()
def load_state_dict(self, state_dict, optimizer_overrides=None):
self._optimizer.load_state_dict(state_dict, optimizer_overrides)
def multiply_grads(self, c):
"""Multiplies grads by a constant *c*."""
self._optimizer.multiply_grads(c)
def clip_grad_norm(self, max_norm):
"""Clips gradient norm."""
return self._optimizer.clip_grad_norm(max_norm)
def _model_average_step(self):
if self.get_num_updates() % self.sync_iter == 0:
size = float(dist.get_world_size())
for p in self.params:
dist.all_reduce(p.data, op=dist.reduce_op.SUM)
p.data /= size
def step(self, closure=None):
"""Performs a single optimization step."""
self._optimizer.step(closure)
self.set_num_updates(self.get_num_updates() + 1)
self._model_average_step()
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self._optimizer.zero_grad()
def get_num_updates(self):
"""Get the number of parameters updates."""
return self._num_updates
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
self._num_updates = num_updates
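
For reference, here is a minimal standalone sketch of how this wrapper could be exercised outside the Trainer. The toy model, the plain SGD settings, and the Namespace values are illustrative assumptions, and torch.distributed is assumed to be initialized (one process per GPU) before any step runs:

# Hypothetical usage sketch (not part of this commit). Assumes
# torch.distributed.init_process_group(...) has already been called,
# e.g. one process per GPU via torch.distributed.launch.
from argparse import Namespace

import torch
import torch.nn as nn

from fairseq import optim
from fairseq.optim.bmuf import FairseqBMUF

# Illustrative hyperparameters: a plain fairseq SGD optimizer plus the
# new --global-sync-iter value consumed by FairseqBMUF.
args = Namespace(optimizer='sgd', lr=[0.1], momentum=0.0, weight_decay=0.0,
                 global_sync_iter=10)

model = nn.Linear(8, 2)                 # toy model for the sketch
params = list(model.parameters())

base_optimizer = optim.build_optimizer(args, params)
optimizer = FairseqBMUF(args, params, base_optimizer)

x, y = torch.randn(4, 8), torch.randn(4, 2)
for _ in range(20):
    loss = ((model(x) - y) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    # Local SGD step; every 10th update all-reduces and averages the
    # parameters across workers via _model_average_step().
    optimizer.step()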
@@ -320,6 +320,10 @@ def add_optimization_args(parser):
                             ' (note: this may be interpreted differently depending on --lr-scheduler)')
     group.add_argument('--min-lr', default=-1, type=float, metavar='LR',
                        help='stop training when the learning rate reaches this minimum')
+    group.add_argument('--use-bmuf', default=False, action='store_true',
+                       help="specify global optimizer for syncing models on different GPUs/Shards")
+    group.add_argument('--global-sync-iter', default=10, type=int,
+                       help='Iteration for syncing global model')
     # fmt: on
     return group
......
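
As a quick sanity check of the new options, the flags added above can be parsed in isolation. Importing add_optimization_args directly and the example values below are assumptions for illustration:

# Hypothetical check (not part of this commit) that the new flags parse.
import argparse

from fairseq.options import add_optimization_args

parser = argparse.ArgumentParser()
add_optimization_args(parser)

args = parser.parse_args(['--use-bmuf', '--global-sync-iter', '20'])
assert args.use_bmuf is True
assert args.global_sync_iter == 20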
@@ -79,7 +79,7 @@ class Trainer(object):
     @property
     def model(self):
         if self._wrapped_model is None:
-            if self.args.distributed_world_size > 1:
+            if self.args.distributed_world_size > 1 and not self.args.use_bmuf:
                 self._wrapped_model = models.DistributedFairseqModel(
                     self.args, self._model,
                 )
@@ -114,6 +114,9 @@ class Trainer(object):
                 print('| NOTICE: your device may support faster training with --fp16')
             self._optimizer = optim.build_optimizer(self.args, params)
+
+        if self.args.use_bmuf:
+            self._optimizer = optim.FairseqBMUF(self.args, params, self._optimizer)

         # We should initialize the learning rate scheduler immediately after
         # building the optimizer, so that the initial learning rate is set.
         self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
@@ -286,7 +289,13 @@ class Trainer(object):
             return None

         # gather logging outputs from all replicas
-        if self.args.distributed_world_size > 1:
+        if self.args.distributed_world_size > 1 and (
+            (not self.args.use_bmuf)
+            or (
+                self.args.use_bmuf
+                and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
+            )
+        ):
             logging_outputs, sample_sizes, ooms, prev_norms = \
                 zip(*distributed_utils.all_gather_list(
                     [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
@@ -294,10 +303,12 @@ class Trainer(object):
             logging_outputs = list(chain.from_iterable(logging_outputs))
             sample_sizes = list(chain.from_iterable(sample_sizes))
             ooms = sum(ooms)
-            assert (
-                all(norm == prev_norms[0] for norm in prev_norms)
-                or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
-            ), 'Fatal error: gradients are inconsistent between workers'
+
+            if not self.args.use_bmuf:
+                assert (
+                    all(norm == prev_norms[0] for norm in prev_norms)
+                    or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
+                ), 'Fatal error: gradients are inconsistent between workers'

         self.meters['oom'].update(ooms, len(samples))
         if ooms == self.args.distributed_world_size * len(samples):
@@ -319,7 +330,8 @@ class Trainer(object):
         try:
             # normalize grads by sample size
-            self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))
+            if sample_size > 0:
+                self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

             # clip grads
             grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
@@ -445,7 +457,7 @@ class Trainer(object):
     def lr_step(self, epoch, val_loss=None):
         """Adjust the learning rate based on the validation loss."""
-        _lr = self.lr_scheduler.step(epoch, val_loss)
+        self.lr_scheduler.step(epoch, val_loss)
         # prefer updating the LR based on the number of steps
         return self.lr_step_update()
......
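
To make the gating in the train_step hunk concrete: with BMUF enabled, logging outputs are only all-gathered on the updates where the model will actually be averaged. A small illustration with assumed values:

# Illustration only (assumed values): which update steps trigger the gather,
# given that get_num_updates() is read *before* the optimizer step increments it.
global_sync_iter = 10          # value of --global-sync-iter
sync_steps = [n for n in range(30) if (n + 1) % global_sync_iter == 0]
print(sync_steps)              # [9, 19, 29] -> the 10th, 20th and 30th updates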