Commit d9836217 authored by Abhimanyu Sharma, committed by Facebook GitHub Bot

Adopt Fairseq MemoryEfficientFP16Optimizer in PyText (#910)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/910

Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1124

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1362

Split the Fairseq MemoryEfficientFP16Optimizer class into two classes so that it can be easily imported in PyText through a wrapper class.
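
The split relies on a standard cooperative-mixin pattern: the shared FP16 logic lives in a mixin whose __init__ forwards along the MRO, so the same code can be combined with either fairseq's FairseqOptimizer base class or a PyText-side base. A minimal self-contained sketch of that pattern follows; the toy class names are illustrative stand-ins, not the real fairseq or PyText types.

    class _FP16LogicMixin(object):
        """Toy stand-in for _MemoryEfficientFP16OptimizerMixin."""

        def __init__(self, *args, **kwargs):
            # forward __init__ call to the next class in the MRO
            super().__init__(*args, **kwargs)

        def multiply_grads(self, c):
            # shared FP16 bookkeeping would live here
            print('scaling grads by', c)


    class FairseqBase(object):
        """Toy stand-in for fairseq.optim.FairseqOptimizer."""
        def __init__(self, args):
            self.args = args


    class PyTextBase(object):
        """Toy stand-in for a PyText optimizer base class."""
        def __init__(self, args):
            self.args = args


    class FairseqFP16Optimizer(_FP16LogicMixin, FairseqBase):
        pass


    class PyTextFP16Optimizer(_FP16LogicMixin, PyTextBase):
        pass


    if __name__ == '__main__':
        # both wrappers reuse the mixin's logic while inheriting from different bases
        FairseqFP16Optimizer({'lr': 0.1}).multiply_grads(128.0)
        PyTextFP16Optimizer({'lr': 0.1}).multiply_grads(128.0)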

Iter 2 - fixed some issues to ensure the code runs correctly on FBLearner.

Iter 3 - fixed review comments, an incorrect import, and lints.

Iter 4 - fixed PyText test breaks.

Iter 5 - fixed PyText test breaks.

Iter 6 - fixed comments and refactored based on a conversation with chenyang.

Reviewed By: chenyangyu1988

Differential Revision: D18410916

fbshipit-source-id: 5238ee553cd2811ed0573825e1c29000980cc489
parent b85fb035
@@ -223,71 +223,11 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
         self.fp32_optimizer.set_lr(lr)
-class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
-    """
-    Wrap an *optimizer* to support FP16 (mixed precision) training.
-    Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not
-    maintain an FP32 copy of the model. We instead expect the optimizer to
-    convert the gradients to FP32 internally and sync the results back to the
-    FP16 model params. This significantly reduces memory usage but slightly
-    increases the time spent in the optimizer.
-    Since this wrapper depends on specific functionality in the wrapped
-    optimizer (i.e., on-the-fly conversion of grads to FP32), only certain
-    optimizers can be wrapped. This is determined by the
-    *supports_memory_efficient_fp16* property.
-    """
-    def __init__(self, args, params, optimizer):
-        if not optimizer.supports_memory_efficient_fp16:
-            raise ValueError(
-                'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
-            )
-        super().__init__(args)
-        self.wrapped_optimizer = optimizer
-        if getattr(args, 'fp16_scale_window', None) is None:
-            if len(args.update_freq) > 1:
-                raise ValueError(
-                    '--fp16-scale-window must be given explicitly when using a '
-                    'custom --update-freq schedule'
-                )
-            scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
-        else:
-            scale_window = args.fp16_scale_window
-        self.scaler = DynamicLossScaler(
-            init_scale=args.fp16_init_scale,
-            scale_window=scale_window,
-            tolerance=args.fp16_scale_tolerance,
-            threshold=args.threshold_loss_scale,
-        )
-    @classmethod
-    def build_optimizer(cls, args, params):
-        """
-        Args:
-            args (argparse.Namespace): fairseq args
-            params (iterable): iterable of parameters to optimize
-        """
-        fp16_optimizer = optim.build_optimizer(args, params)
-        return cls(args, params, fp16_optimizer)
-    @property
-    def optimizer(self):
-        return self.wrapped_optimizer.optimizer
+class _MemoryEfficientFP16OptimizerMixin(object):
-    @property
-    def optimizer_config(self):
-        return self.wrapped_optimizer.optimizer_config
-    def get_lr(self):
-        return self.wrapped_optimizer.get_lr()
-    def set_lr(self, lr):
-        self.wrapped_optimizer.set_lr(lr)
+    def __init__(self, *args, **kwargs):
+        # forward __init__ call to the next class in mro(method resolution order)
+        super().__init__(*args, **kwargs)
     def state_dict(self):
         """Return the optimizer's state dict."""
@@ -363,14 +303,14 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
         overflow = DynamicLossScaler.has_overflow(grad_norm)
         self.scaler.update_scale(overflow)
         if overflow:
-            if self.scaler.loss_scale <= self.args.min_loss_scale:
+            if self.scaler.loss_scale <= self.min_loss_scale:
                 # Use FloatingPointError as an uncommon error that parent
                 # functions can safely catch to stop training.
                 raise FloatingPointError((
                     'Minimum loss scale reached ({}). Your loss is probably exploding. '
                     'Try lowering the learning rate, using gradient clipping or '
                     'increasing the batch size.'
-                ).format(self.args.min_loss_scale))
+                ).format(self.min_loss_scale))
             raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
         return grad_norm
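
The two-line change in this hunk only swaps self.args.min_loss_scale for the new self.min_loss_scale attribute; the overflow protocol itself is unchanged. As a rough caller-side sketch of that protocol (the functions below are stand-ins, not fairseq's trainer): OverflowError means the loss scale was just lowered and the update should simply be skipped, while FloatingPointError means the minimum loss scale was reached and training should stop.

    def run_update(clip_grad_norm_fn, max_norm=25.0):
        # stand-in for one trainer step calling the optimizer's clip_grad_norm
        try:
            return clip_grad_norm_fn(max_norm)
        except OverflowError as exc:
            # the loss scale was reduced; skip this update and continue training
            print('gradient overflow, skipping update:', exc)
            return None
        # FloatingPointError is intentionally not caught here, so it bubbles up
        # and stops training once the minimum loss scale has been reached


    if __name__ == '__main__':
        def overflowing_step(max_norm):
            # stand-in for clip_grad_norm on a step with overflowing gradients
            raise OverflowError('setting loss scale to: 64.0')

        print(run_update(overflowing_step))  # prints the warning, then None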
@@ -384,3 +324,71 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
         """Clears the gradients of all optimized parameters."""
         self.wrapped_optimizer.zero_grad()
         self._grads_are_scaled = False
+class MemoryEfficientFP16Optimizer(_MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer):
+    """
+    Wrap an *optimizer* to support FP16 (mixed precision) training.
+    Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not
+    maintain an FP32 copy of the model. We instead expect the optimizer to
+    convert the gradients to FP32 internally and sync the results back to the
+    FP16 model params. This significantly reduces memory usage but slightly
+    increases the time spent in the optimizer.
+    Since this wrapper depends on specific functionality in the wrapped
+    optimizer (i.e., on-the-fly conversion of grads to FP32), only certain
+    optimizers can be wrapped. This is determined by the
+    *supports_memory_efficient_fp16* property.
+    """
+    def __init__(self, args, params, optimizer):
+        if not optimizer.supports_memory_efficient_fp16:
+            raise ValueError(
+                'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
+            )
+        super().__init__(args)
+        self.wrapped_optimizer = optimizer
+        if getattr(args, 'fp16_scale_window', None) is None:
+            if len(args.update_freq) > 1:
+                raise ValueError(
+                    '--fp16-scale-window must be given explicitly when using a '
+                    'custom --update-freq schedule'
+                )
+            scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
+        else:
+            scale_window = args.fp16_scale_window
+        self.scaler = DynamicLossScaler(
+            init_scale=args.fp16_init_scale,
+            scale_window=scale_window,
+            tolerance=args.fp16_scale_tolerance,
+            threshold=args.threshold_loss_scale,
+        )
+        self.min_loss_scale = self.args.min_loss_scale
+    @classmethod
+    def build_optimizer(cls, args, params):
+        """
+        Args:
+            args (argparse.Namespace): fairseq args
+            params (iterable): iterable of parameters to optimize
+        """
+        fp16_optimizer = optim.build_optimizer(args, params)
+        return cls(args, params, fp16_optimizer)
+    @property
+    def optimizer(self):
+        return self.wrapped_optimizer.optimizer
+    @property
+    def optimizer_config(self):
+        return self.wrapped_optimizer.optimizer_config
+    def get_lr(self):
+        return self.wrapped_optimizer.get_lr()
+    def set_lr(self, lr):
+        self.wrapped_optimizer.set_lr(lr)
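
In fairseq, this wrapper is typically constructed through MemoryEfficientFP16Optimizer.build_optimizer(args, params) when memory-efficient FP16 training is enabled. When --fp16-scale-window is not given, __init__ above derives the loss-scale window as 2**14 divided by the data-parallel world size and by the gradient-accumulation factor. A small self-contained sketch of just that arithmetic (plain Python, no fairseq imports; the helper name is illustrative, not part of fairseq):

    def default_scale_window(distributed_world_size, update_freq):
        # mirrors the default computed in MemoryEfficientFP16Optimizer.__init__ above
        if len(update_freq) > 1:
            raise ValueError(
                '--fp16-scale-window must be given explicitly when using a '
                'custom --update-freq schedule'
            )
        return 2 ** 14 / distributed_world_size / update_freq[0]


    if __name__ == '__main__':
        # 8 GPUs, no gradient accumulation: wait 2048 overflow-free updates
        # before the scaler tries to increase the loss scale again
        print(default_scale_window(distributed_world_size=8, update_freq=[1]))   # 2048.0
        # 64 GPUs with update_freq 2: only 128 updates between scale increases
        print(default_scale_window(distributed_world_size=64, update_freq=[2]))  # 128.0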