Commit 58e43cb3 authored by Chenyang Yu, committed by Facebook Github Bot

extract FP16OptimizerMixin to share the same logic in PyText (#1180)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1180

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/874

extract FP16OptimizerMixin to share the same logic in PyText

Reviewed By: hudeven

Differential Revision: D17594102

fbshipit-source-id: 8625a4e4f3e09cbaba6ae92599c1121b86ed4e78
parent 1c667929
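
The change leans on Python's cooperative multiple inheritance: `_FP16OptimizerMixin` forwards `__init__` along the method resolution order (MRO), so the same FP16 bookkeeping can sit in front of fairseq's `FairseqOptimizer` today and in front of a PyText optimizer base class later. Below is a minimal sketch of that pattern; the class names `SharedFP16Mixin`, `SomeBaseOptimizer` and `MyFP16Optimizer` are illustrative stand-ins, not fairseq or PyText classes.

# Minimal sketch of the cooperative-mixin pattern used by this commit.
class SharedFP16Mixin(object):

    def __init__(self, *args, **kwargs):
        # forward the call to the next class in the MRO, so the mixin can be
        # combined with any base optimizer without knowing its signature
        super().__init__(*args, **kwargs)

    def describe(self):
        return 'FP16 logic shared via {}'.format(type(self).__name__)


class SomeBaseOptimizer(object):

    def __init__(self, lr):
        self.lr = lr


class MyFP16Optimizer(SharedFP16Mixin, SomeBaseOptimizer):
    # the mixin is listed first so its methods take precedence, while its
    # __init__ still reaches SomeBaseOptimizer.__init__ through super()
    pass


opt = MyFP16Optimizer(lr=0.1)
print(opt.lr)          # 0.1 -- the base class __init__ ran via the MRO
print(opt.describe())  # FP16 logic shared via MyFP16Optimizer
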
@@ -54,41 +54,14 @@ class DynamicLossScaler(object):
         return False


-class FP16Optimizer(optim.FairseqOptimizer):
-    """
-    Wrap an *optimizer* to support FP16 (mixed precision) training.
-    """
-
-    def __init__(self, args, params, fp32_optimizer, fp32_params):
-        super().__init__(args)
-        self.fp16_params = params
-        self.fp32_optimizer = fp32_optimizer
-        self.fp32_params = fp32_params
-
-        if getattr(args, 'fp16_scale_window', None) is None:
-            if len(args.update_freq) > 1:
-                raise ValueError(
-                    '--fp16-scale-window must be given explicitly when using a '
-                    'custom --update-freq schedule'
-                )
-            scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
-        else:
-            scale_window = args.fp16_scale_window
-
-        self.scaler = DynamicLossScaler(
-            init_scale=args.fp16_init_scale,
-            scale_window=scale_window,
-            tolerance=args.fp16_scale_tolerance,
-            threshold=args.threshold_loss_scale,
-        )
-
+class _FP16OptimizerMixin(object):
+
+    def __init__(self, *args, **kwargs):
+        # forward __init__ call to the next class in mro (method resolution order)
+        super().__init__(*args, **kwargs)
+
     @classmethod
-    def build_optimizer(cls, args, params):
-        """
-        Args:
-            args (argparse.Namespace): fairseq args
-            params (iterable): iterable of parameters to optimize
-        """
+    def build_fp32_params(cls, params):
         # create FP32 copy of parameters and grads
         total_param_size = sum(p.data.numel() for p in params)
         fp32_params = params[0].new(0).float().new(total_param_size)
@@ -99,23 +72,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
             offset += numel
         fp32_params = torch.nn.Parameter(fp32_params)
         fp32_params.grad = fp32_params.data.new(total_param_size)
-
-        fp32_optimizer = optim.build_optimizer(args, [fp32_params])
-        return cls(args, params, fp32_optimizer, fp32_params)
-
-    @property
-    def optimizer(self):
-        return self.fp32_optimizer.optimizer
-
-    @property
-    def optimizer_config(self):
-        return self.fp32_optimizer.optimizer_config
-
-    def get_lr(self):
-        return self.fp32_optimizer.get_lr()
-
-    def set_lr(self, lr):
-        self.fp32_optimizer.set_lr(lr)
+        return fp32_params

     def state_dict(self):
         """Return the optimizer's state dict."""
@@ -179,14 +136,14 @@ class FP16Optimizer(optim.FairseqOptimizer):
         overflow = DynamicLossScaler.has_overflow(grad_norm)
         self.scaler.update_scale(overflow)
         if overflow:
-            if self.scaler.loss_scale <= self.args.min_loss_scale:
+            if self.scaler.loss_scale <= self.min_loss_scale:
                 # Use FloatingPointError as an uncommon error that parent
                 # functions can safely catch to stop training.
                 raise FloatingPointError((
                     'Minimum loss scale reached ({}). Your loss is probably exploding. '
                     'Try lowering the learning rate, using gradient clipping or '
                     'increasing the batch size.'
-                ).format(self.args.min_loss_scale))
+                ).format(self.min_loss_scale))
             raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

         return grad_norm
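
The overflow branch above is one half of dynamic loss scaling: when gradients hit inf/NaN the scale is lowered and the update is skipped, and after `scale_window` clean updates the scale is raised again. Below is a simplified sketch of that control flow covering only the core grow/shrink behaviour; fairseq's `DynamicLossScaler` additionally supports a tolerance and a fixed threshold.

# Simplified sketch of a dynamic loss-scaling policy; not fairseq's class.
class SimpleLossScaler(object):

    def __init__(self, init_scale=2.**7, scale_factor=2., scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._updates_since_rescale = 0

    def update_scale(self, overflow):
        if overflow:
            # gradients hit inf/nan: shrink the scale and restart the window
            self.loss_scale /= self.scale_factor
            self._updates_since_rescale = 0
        else:
            self._updates_since_rescale += 1
            if self._updates_since_rescale >= self.scale_window:
                # a long run of clean updates: it is safe to grow the scale again
                self.loss_scale *= self.scale_factor
                self._updates_since_rescale = 0


def has_overflow(grad_norm):
    # the usual inf/nan check (nan is the only value not equal to itself)
    return grad_norm == float('inf') or grad_norm != grad_norm
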
@@ -211,6 +168,61 @@ class FP16Optimizer(optim.FairseqOptimizer):
             self._needs_sync = False


+class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
+    """
+    Wrap an *optimizer* to support FP16 (mixed precision) training.
+    """
+
+    def __init__(self, args, params, fp32_optimizer, fp32_params):
+        super().__init__(args)
+        self.fp16_params = params
+        self.fp32_optimizer = fp32_optimizer
+        self.fp32_params = fp32_params
+
+        if getattr(args, 'fp16_scale_window', None) is None:
+            if len(args.update_freq) > 1:
+                raise ValueError(
+                    '--fp16-scale-window must be given explicitly when using a '
+                    'custom --update-freq schedule'
+                )
+            scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
+        else:
+            scale_window = args.fp16_scale_window
+
+        self.scaler = DynamicLossScaler(
+            init_scale=args.fp16_init_scale,
+            scale_window=scale_window,
+            tolerance=args.fp16_scale_tolerance,
+            threshold=args.threshold_loss_scale,
+        )
+        self.min_loss_scale = self.args.min_loss_scale
+
+    @classmethod
+    def build_optimizer(cls, args, params):
+        """
+        Args:
+            args (argparse.Namespace): fairseq args
+            params (iterable): iterable of parameters to optimize
+        """
+        fp32_params = cls.build_fp32_params(params)
+        fp32_optimizer = optim.build_optimizer(args, [fp32_params])
+        return cls(args, params, fp32_optimizer, fp32_params)
+
+    @property
+    def optimizer(self):
+        return self.fp32_optimizer.optimizer
+
+    @property
+    def optimizer_config(self):
+        return self.fp32_optimizer.optimizer_config
+
+    def get_lr(self):
+        return self.fp32_optimizer.get_lr()
+
+    def set_lr(self, lr):
+        self.fp32_optimizer.set_lr(lr)
+
+
 class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
     """
     Wrap an *optimizer* to support FP16 (mixed precision) training.
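
Taken together, `build_optimizer` flattens the FP16 parameters through the mixin's `build_fp32_params`, wraps an FP32 optimizer around the master copy, and the scaler guards each update. Below is a rough usage sketch of one training step against this API; `args`, `model`, `criterion` and `batch` are placeholders, `optimizer` would come from `FP16Optimizer.build_optimizer(args, list(model.parameters()))`, and `backward`/`clip_grad_norm`/`step`/`zero_grad` are the FP16Optimizer methods defined outside the visible hunks.

# Rough usage sketch; not fairseq's trainer code.
def fp16_train_step(args, model, criterion, batch, optimizer):
    """One update with the wrapped optimizer; skips the step on overflow."""
    loss = criterion(model, batch)
    optimizer.backward(loss)                      # scales the loss before backward()
    try:
        optimizer.clip_grad_norm(args.clip_norm)  # unscale + clip; detects overflow
        optimizer.step()                          # FP32 update, copied back to FP16 params
    except OverflowError:
        optimizer.zero_grad()                     # drop this update; the loss scale was lowered
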