Commit ed8236fa authored by Michael Carilli

Fix for unscale usage in fp16_utils.FP16_Optimizer

parent d137b800
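At a glance, the flow this commit moves amp's loss scaling toward is: clear the overflow state, unscale model gradients into master gradients, then update the scale and possibly skip the step. A minimal sketch of that sequence follows; it is not code from the commit, and `scaler`, `params`, and `master_params` are placeholders standing in for a LossScaler instance and the parameter iterables amp passes to it.

def amp_style_step(loss, optimizer, scaler, params, master_params):
    # Hypothetical helper illustrating the order of operations; `scaler` is assumed
    # to expose the LossScaler methods shown in this diff.
    scaled_loss = loss * scaler.loss_scale()
    scaled_loss.backward()

    scaler.clear_overflow_state()            # new in this commit: reset per-iteration overflow state
    scaler.unscale(params, master_params,    # unscale model grads into master grads
                   scaler.loss_scale())
    should_skip = scaler.update_scale()      # returns True if an overflow was detected
    if not should_skip:
        optimizer.step()                     # on overflow, the step is skipped and the scale reduced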
 import torch
 from torch._six import container_abcs, string_classes
 import functools
-from apex.fp16_utils import convert_network
 from ._amp_state import _amp_state
 from .scaler import LossScaler
+from apex.fp16_utils import convert_network
 from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
 from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
 from ..optimizers import FusedAdam
+from ..parallel import DistributedDataParallel as apex_DDP

 def to_type(dtype, t):
@@ -71,7 +72,7 @@ def check_optimizers(optimizers):
         bad_optim_type = None
         if isinstance(optim, FP16_Optimizer_general):
             bad_optim_type = "apex.fp16_utils.FP16_Optimizer"
-        if isinstance(model, FP16_Optimizer_for_fused):
+        if isinstance(optim, FP16_Optimizer_for_fused):
             bad_optim_type = "apex.optimizers.FP16_Optimizer"
         if bad_optim_type is not None:
             raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
@@ -81,7 +82,7 @@ def check_optimizers(optimizers):
                                "soon). You should not manually wrap your optimizer in either \n"
                                "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
                                "amp.initialize will take care of that for you (if necessary) based \n"
-                               "on the specified opt_level (and optional overridden properties)."
+                               "on the specified opt_level (and optional overridden properties).")


 def _initialize(models, optimizers, properties):
@@ -141,9 +142,11 @@ def _initialize(models, optimizers, properties):
             if isinstance(optimizer, FusedAdam):
                 optimizers[i] = wrap_fused_adam(optimizer, properties)
             if properties.loss_scale == "dynamic":
-                optimizers[i] = FP16_Optimizer_general(optimizers[i], dynamic_loss_scale=True)
+                optimizers[i] = FP16_Optimizer_general(optimizers[i],
+                                                       dynamic_loss_scale=True)
             else:
-                optimizers[i] = FP16_Optimizer(optimizers[i], static_loss_scale=properties.loss_scale)
+                optimizers[i] = FP16_Optimizer_general(optimizers[i],
+                                                       static_loss_scale=properties.loss_scale)
     else:
         for optimizer in optimizers:
             optimizer.loss_scaler = LossScaler(properties.loss_scale)
...
@@ -91,7 +91,7 @@ class O2:
         properties.opt_level = "O2"
         properties.cast_model_type = torch.float16
         properties.patch_torch_functions = False
-        properties.keep_batchnorm_fp32 = torch.float32
+        properties.keep_batchnorm_fp32 = True
         properties.master_weights = True
         properties.loss_scale = "dynamic"
         properties.fused_optimizer = False
@@ -174,6 +174,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
                enable_ddp_interop=None):
     """
     if not enabled:
+        _amp_state.opt_properties = Properties()
         return models, optimizers

     if opt_level not in opt_levels:
@@ -186,7 +187,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
     print("Defaults for this optimization level are:")
     print(_amp_state.opt_properties.options)
     for k, v in _amp_state.opt_properties.options.items():
-        print("{:20} : {}".format(k, v))
+        print("{:22} : {}".format(k, v))

     print("Processing user overrides (additional kwargs that are not None)...")
     for k, v in kwargs.items():
@@ -197,7 +198,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
     print("After processing overrides, optimization options are:")
     for k, v in _amp_state.opt_properties.options.items():
-        print("{:20} : {}".format(k, v))
+        print("{:22} : {}".format(k, v))

     return _initialize(models, optimizers, _amp_state.opt_properties)
@@ -228,7 +229,7 @@ def check_option_consistency(enabled=True,
     print("Selected optimization level {}", opt_levels[opt_level].brief)
     print("Defaults for this optimization level are:")
     for k, v in opt_properties.options:
-        print("{:20} : {}".format(k, v))
+        print("{:22} : {}".format(k, v))

     print("Processing user overrides (additional kwargs that are not None)...")
     for k, v in kwargs:
@@ -239,4 +240,4 @@ def check_option_consistency(enabled=True,
     print("After processing overrides, optimization options are:")
     for k, v in opt_properties.options:
-        print("{:20} : {}".format(k, v))
+        print("{:22} : {}".format(k, v))
@@ -45,11 +45,13 @@ def scale_loss(loss,
     if isinstance(optimizer, FP16_Optimizer):
         optimizer.update_master_grads()
     else:
+        optimizer.loss_scaler.clear_overflow_state()
         optimizer.loss_scaler.unscale(
             iter_params(optimizer.param_groups),
             iter_params(optimizer.param_groups),
             loss_scale)
-    # If overflow_check_on_cpu is False, should_skip will always be False.
+    # In the future, once I have fused optimizers that enable sync-free dynamic loss scaling,
+    # should_skip will always be False.
     should_skip = optimizer.loss_scaler.update_scale()
     if should_skip:
         optimizer_step = optimizer.step
@@ -101,6 +103,7 @@ class AmpHandle(object):
         loss_scale = self._default_scaler.loss_scale()
         yield loss * loss_scale

+        self._default_scaler.clear_overflow_state()
         self._default_scaler.unscale(
             iter_params(optimizer.param_groups),
             iter_params(optimizer.param_groups),
...
@@ -37,6 +37,7 @@ class OptimWrapper(object):
         loss_scale = self._cur_loss_scaler().loss_scale()
         yield loss * loss_scale

+        self._cur_loss_scaler().clear_overflow_state()
         self._cur_loss_scaler().unscale(
             iter_params(self._optimizer.param_groups),
             iter_params(self._optimizer.param_groups),
...
@@ -5,22 +5,18 @@ from ._amp_state import _amp_state
 # from apex_C import scale_check_overflow

-def scale_check_overflow_python(model_grad, scale, master_grad):
+def scale_check_overflow_python(model_grad, scale, master_grad, check_overflow=False):
     # Exception handling for 18.04 compatibility
-    try:
+    if check_overflow:
         cpu_sum = float(model_grad.float().sum())
-    except RuntimeError as instance:
-        if "value cannot be converted" not in instance.args[0]:
-            raise
-        return True
-    else:
         if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
             return True
-    if master_grad is not model_grad:
+
+    if master_grad is not model_grad: # copy_ probably internally short-circuits this
         master_grad.copy_(model_grad)
     if scale != 1.0:
         master_grad.mul_(scale)
     return False

 class LossScaler(object):
     warned_no_fused_kernel = False
@@ -73,12 +69,21 @@ class LossScaler(object):
                 self._has_overflow = scale_check_overflow_python(
                     model,
                     1./scale,
-                    master)
+                    master,
+                    self.dynamic)
                 if self._has_overflow and self.dynamic:
                     break

-    def unscale(self, model_params, master_params, scale):
+    def clear_overflow_state(self):
         self._has_overflow = False
+        if self.has_fused_kernel:
+            self._overflow_buf.zero_()
+
+    def unscale(self, model_params, master_params, scale):
+        # torch.cuda.nvtx.range_push("unscale")
+        if self._has_overflow:
+            # torch.cuda.nvtx.range_pop()
+            return

         # Lots of defensive list processing going on here. Way more less efficient than
         # consuming the iterator directly. Need to examine Python overhead.
@@ -112,12 +117,12 @@ class LossScaler(object):
             # Warning: setting this to True unconditionally allows the possibility of an escape
             # if never-before-seen non-fp32 grads are created in some later iteration.
             LossScaler.warned_unscaling_non_fp32_grad = True
-        self._overflow_buf.zero_()

         # handle case of opt_level O1 and loss_scale 1.0. There's also some
         # special-cased yields in scale_loss to potentially short-circuit earlier.
         # TODO: Profile and find out if all the O(N) list processing in unscale()
         # is a bottleneck.
         if scale == 1.0 and all_same and not self.dynamic:
+            # torch.cuda.nvtx.range_pop()
             return
         else:
             multi_tensor_applier(
@@ -128,12 +133,14 @@ class LossScaler(object):
         else:
             self.unscale_grads_python(model_grads, master_grads, scale)

-    # Break into multiple param groups so unscale() can be called more that once before updating.
-    def update_scale(self):
         # If the fused kernel is available, we only need one D2H memcopy and sync.
         if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
             self._has_overflow = self._overflow_buf.item()

+        # torch.cuda.nvtx.range_pop()
+
+    # Separate so unscale() can be called more that once before updating.
+    def update_scale(self):
         if self._has_overflow and self.dynamic:
             should_skip = True
             self._loss_scale /= 2.
...
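For readers without the fused kernel, here is a small standalone exercise of the Python fallback's new check_overflow path. The function body mirrors the updated fallback above; the example tensors and prints are made up for illustration.

import torch

def scale_check_overflow_python(model_grad, scale, master_grad, check_overflow=False):
    # Same logic as the updated fallback above: optional inf/nan check,
    # then copy the model grad into the master grad and apply the scale.
    if check_overflow:
        cpu_sum = float(model_grad.float().sum())
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
    if master_grad is not model_grad:
        master_grad.copy_(model_grad)
    if scale != 1.0:
        master_grad.mul_(scale)
    return False

model_grad = torch.tensor([0.5, float('inf')], dtype=torch.float16)
master_grad = torch.zeros(2, dtype=torch.float32)

# With check_overflow=True (the dynamic loss scaling case), the inf is reported
# and nothing is copied into the master grad.
print(scale_check_overflow_python(model_grad, 1.0 / 128, master_grad, check_overflow=True))  # True

# A finite gradient is copied and unscaled into the fp32 master grad.
finite_grad = torch.tensor([64.0, -128.0], dtype=torch.float16)
print(scale_check_overflow_python(finite_grad, 1.0 / 128, master_grad, check_overflow=True))  # False
print(master_grad)  # tensor([ 0.5000, -1.0000])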
@@ -393,7 +393,9 @@ class FP16_Optimizer(object):
         if closure is not None:
             retval = self._step_with_closure(closure)
         else:
+            # torch.cuda.nvtx.range_push("pytorch optimizer step")
             retval = self.optimizer.step()
+            # torch.cuda.nvtx.range_pop()

         self._master_params_to_model_params()
@@ -502,6 +504,7 @@ class FP16_Optimizer(object):
             self.update_master_grads()

     def update_master_grads(self):
+        # torch.cuda.nvtx.range_push("update_master_grads")
         """
         Copy the ``.grad`` attribute from stored references to fp16 parameters to
         the ``.grad`` attribute of the fp32 master parameters that are directly
@@ -514,6 +517,7 @@ class FP16_Optimizer(object):
         # self._model_grads_to_master_grads()
         # self._downscale_master()
         # Use the one-shot multi-tensor apply kernel
+        self.loss_scaler.clear_overflow_state()
         if len(self.all_fp16_params) > 0:
             # print("Model grads before")
             # print([param.grad.data for param in self.all_fp16_params])
@@ -534,6 +538,7 @@ class FP16_Optimizer(object):
             # print([param.grad.data for param in self.all_fp32_from_fp32_params])
             # quit()
         self.overflow = self.loss_scaler.update_scale()
+        # torch.cuda.nvtx.range_pop()

     def inspect_master_grad_data(self):
...
@@ -95,6 +95,7 @@ if args.deterministic:
     cudnn.benchmark = False
     cudnn.deterministic = True
     torch.manual_seed(args.local_rank)
+    torch.set_printoptions(precision=10)

 # Initialize Amp
 amp_handle = amp.init(enabled=args.fp16)
@@ -337,7 +338,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Speed {3:.3f} ({4:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch, i, len(train_loader),
...
@@ -99,6 +99,7 @@ if args.deterministic:
     cudnn.benchmark = False
     cudnn.deterministic = True
     torch.manual_seed(args.local_rank)
+    torch.set_printoptions(precision=10)

 def main():
     global best_prec1, args
@@ -344,7 +345,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Speed {3:.3f} ({4:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch, i, len(train_loader),
...
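The two example scripts above only gain higher print precision for loss comparison. For context, their training loops drive the handle whose scale_loss context manager is modified in this commit, roughly as in the sketch below; this is not code from the scripts, and model, criterion, images, and target are placeholders.

from apex import amp

# Hypothetical training-step sketch exercising the modified scale_loss path.
amp_handle = amp.init(enabled=True)

def train_step(model, criterion, optimizer, images, target):
    output = model(images)
    loss = criterion(output, target)
    optimizer.zero_grad()
    # On exit, the handle clears overflow state, unscales gradients, and updates the loss scale.
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()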