Commit ed8236fa authored by Michael Carilli

Fix for unscale usage in fp16_utils.FP16_Optimizer

parent d137b800
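The changes below split per-iteration loss-scaling work into three explicit calls on the loss scaler: clear_overflow_state() resets the sticky overflow flag, unscale() copies and unscales gradients (and may run more than once before the scale is adjusted), and update_scale() shrinks the scale and reports whether the optimizer step should be skipped. A minimal sketch of that calling pattern follows; scaler, model_grads, and master_grads are hypothetical stand-ins whose method names merely mirror the diff, not apex objects.

def training_iteration(loss, optimizer, scaler, model_grads, master_grads):
    # Backward pass on the scaled loss so small fp16 gradients don't flush to zero.
    (loss * scaler.loss_scale()).backward()

    scaler.clear_overflow_state()            # reset the sticky overflow flag once
    scaler.unscale(model_grads,              # may be called repeatedly (e.g. per
                   master_grads,             # param group) before update_scale()
                   scaler.loss_scale())
    should_skip = scaler.update_scale()      # halve the scale and skip on overflow

    if not should_skip:
        optimizer.step()
    optimizer.zero_grad()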
import torch
from torch._six import container_abcs, string_classes
import functools
from apex.fp16_utils import convert_network
from ._amp_state import _amp_state
from .scaler import LossScaler
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
from ..optimizers import FusedAdam
from ..parallel import DistributedDataParallel as apex_DDP
def to_type(dtype, t):
@@ -71,7 +72,7 @@ def check_optimizers(optimizers):
        bad_optim_type = None
        if isinstance(optim, FP16_Optimizer_general):
            bad_optim_type = "apex.fp16_utils.FP16_Optimizer"
-        if isinstance(model, FP16_Optimizer_for_fused):
+        if isinstance(optim, FP16_Optimizer_for_fused):
            bad_optim_type = "apex.optimizers.FP16_Optimizer"
        if bad_optim_type is not None:
            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
@@ -81,7 +82,7 @@ def check_optimizers(optimizers):
                               "soon). You should not manually wrap your optimizer in either \n"
                               "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
                               "amp.initialize will take care of that for you (if necessary) based \n"
-                              "on the specified opt_level (and optional overridden properties)."
+                              "on the specified opt_level (and optional overridden properties).")
def _initialize(models, optimizers, properties):
@@ -141,9 +142,11 @@ def _initialize(models, optimizers, properties):
            if isinstance(optimizer, FusedAdam):
                optimizers[i] = wrap_fused_adam(optimizer, properties)
            if properties.loss_scale == "dynamic":
-                optimizers[i] = FP16_Optimizer_general(optimizers[i], dynamic_loss_scale=True)
+                optimizers[i] = FP16_Optimizer_general(optimizers[i],
+                                                       dynamic_loss_scale=True)
            else:
-                optimizers[i] = FP16_Optimizer(optimizers[i], static_loss_scale=properties.loss_scale)
+                optimizers[i] = FP16_Optimizer_general(optimizers[i],
+                                                       static_loss_scale=properties.loss_scale)
    else:
        for optimizer in optimizers:
            optimizer.loss_scaler = LossScaler(properties.loss_scale)
......
@@ -91,7 +91,7 @@ class O2:
        properties.opt_level = "O2"
        properties.cast_model_type = torch.float16
        properties.patch_torch_functions = False
-        properties.keep_batchnorm_fp32 = torch.float32
+        properties.keep_batchnorm_fp32 = True
        properties.master_weights = True
        properties.loss_scale = "dynamic"
        properties.fused_optimizer = False
@@ -174,6 +174,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
                   enable_ddp_interop=None):
    """
    if not enabled:
+        _amp_state.opt_properties = Properties()
        return models, optimizers

    if opt_level not in opt_levels:
@@ -186,7 +187,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
    print("Defaults for this optimization level are:")
    print(_amp_state.opt_properties.options)
    for k, v in _amp_state.opt_properties.options.items():
-        print("{:20} : {}".format(k, v))
+        print("{:22} : {}".format(k, v))
    print("Processing user overrides (additional kwargs that are not None)...")
    for k, v in kwargs.items():
@@ -197,7 +198,7 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
    print("After processing overrides, optimization options are:")
    for k, v in _amp_state.opt_properties.options.items():
-        print("{:20} : {}".format(k, v))
+        print("{:22} : {}".format(k, v))

    return _initialize(models, optimizers, _amp_state.opt_properties)
@@ -228,7 +229,7 @@ def check_option_consistency(enabled=True,
    print("Selected optimization level {}", opt_levels[opt_level].brief)
    print("Defaults for this optimization level are:")
    for k, v in opt_properties.options:
-        print("{:20} : {}".format(k, v))
+        print("{:22} : {}".format(k, v))
    print("Processing user overrides (additional kwargs that are not None)...")
    for k, v in kwargs:
@@ -239,4 +240,4 @@ def check_option_consistency(enabled=True,
    print("After processing overrides, optimization options are:")
    for k, v in opt_properties.options:
-        print("{:20} : {}".format(k, v))
+        print("{:22} : {}".format(k, v))
@@ -45,11 +45,13 @@ def scale_loss(loss,
    if isinstance(optimizer, FP16_Optimizer):
        optimizer.update_master_grads()
    else:
+        optimizer.loss_scaler.clear_overflow_state()
        optimizer.loss_scaler.unscale(
            iter_params(optimizer.param_groups),
            iter_params(optimizer.param_groups),
            loss_scale)
-    # If overflow_check_on_cpu is False, should_skip will always be False.
+    # In the future, once I have fused optimizers that enable sync-free dynamic loss scaling,
+    # should_skip will always be False.
    should_skip = optimizer.loss_scaler.update_scale()
    if should_skip:
        optimizer_step = optimizer.step
@@ -101,6 +103,7 @@ class AmpHandle(object):
        loss_scale = self._default_scaler.loss_scale()
        yield loss * loss_scale

+        self._default_scaler.clear_overflow_state()
        self._default_scaler.unscale(
            iter_params(optimizer.param_groups),
            iter_params(optimizer.param_groups),
......
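In scale_loss above, the same expression is passed for both the model and master gradient iterators because, when no separate fp32 master weights exist, they are the same tensors; clear_overflow_state() now runs first so the overflow flag accumulated across one or more unscale() calls starts clean each iteration. iter_params is an apex helper; it presumably behaves like the following generator, shown only so the call above reads clearly (an assumption, not the actual implementation).

def iter_params(param_groups):
    # Flatten an optimizer's param_groups into a single stream of parameters.
    for group in param_groups:
        for p in group['params']:
            yield p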
@@ -37,6 +37,7 @@ class OptimWrapper(object):
        loss_scale = self._cur_loss_scaler().loss_scale()
        yield loss * loss_scale

+        self._cur_loss_scaler().clear_overflow_state()
        self._cur_loss_scaler().unscale(
            iter_params(self._optimizer.param_groups),
            iter_params(self._optimizer.param_groups),
......
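Both AmpHandle.scale_loss and OptimWrapper.scale_loss are context managers: they yield the scaled loss, and on exit they now clear the overflow state before unscaling. From the caller's side the pattern matches the amp.init usage in the example scripts further down; the snippet below is a self-contained sketch of that pattern (model, data, and hyperparameters are arbitrary placeholders).

import torch
from apex import amp

model = torch.nn.Linear(8, 2).cuda()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
amp_handle = amp.init(enabled=True)          # same entry point as the example scripts below

inputs = torch.randn(4, 8).cuda()
targets = torch.randint(0, 2, (4,)).cuda()

loss = criterion(model(inputs), targets)
optimizer.zero_grad()
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()                   # gradients carry the loss-scale factor here
optimizer.step()                             # clear_overflow_state()/unscale() ran on context exit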
@@ -5,22 +5,18 @@ from ._amp_state import _amp_state

# from apex_C import scale_check_overflow

-def scale_check_overflow_python(model_grad, scale, master_grad):
+def scale_check_overflow_python(model_grad, scale, master_grad, check_overflow=False):
    # Exception handling for 18.04 compatibility
-    try:
+    if check_overflow:
        cpu_sum = float(model_grad.float().sum())
-    except RuntimeError as instance:
-        if "value cannot be converted" not in instance.args[0]:
-            raise
-        return True
-    else:
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
-        if master_grad is not model_grad:
-            master_grad.copy_(model_grad)
-        if scale != 1.0:
-            master_grad.mul_(scale)
-        return False
+
+    if master_grad is not model_grad: # copy_ probably internally short-circuits this
+        master_grad.copy_(model_grad)
+    if scale != 1.0:
+        master_grad.mul_(scale)
+    return False

class LossScaler(object):
    warned_no_fused_kernel = False
@@ -73,12 +69,21 @@ class LossScaler(object):
                self._has_overflow = scale_check_overflow_python(
                    model,
                    1./scale,
-                    master)
+                    master,
+                    self.dynamic)
                if self._has_overflow and self.dynamic:
                    break

-    def unscale(self, model_params, master_params, scale):
+    def clear_overflow_state(self):
+        self._has_overflow = False
+        if self.has_fused_kernel:
+            self._overflow_buf.zero_()
+
+    def unscale(self, model_params, master_params, scale):
        # torch.cuda.nvtx.range_push("unscale")
        if self._has_overflow:
            # torch.cuda.nvtx.range_pop()
            return

        # Lots of defensive list processing going on here. Way less efficient than
        # consuming the iterator directly. Need to examine Python overhead.
@@ -112,12 +117,12 @@ class LossScaler(object):
                    # Warning: setting this to True unconditionally allows the possibility of an escape
                    # if never-before-seen non-fp32 grads are created in some later iteration.
                    LossScaler.warned_unscaling_non_fp32_grad = True
-        self._overflow_buf.zero_()
        # handle case of opt_level O1 and loss_scale 1.0. There's also some
        # special-cased yields in scale_loss to potentially short-circuit earlier.
        # TODO: Profile and find out if all the O(N) list processing in unscale()
        # is a bottleneck.
        if scale == 1.0 and all_same and not self.dynamic:
+            # torch.cuda.nvtx.range_pop()
            return
        else:
            multi_tensor_applier(
@@ -128,12 +133,14 @@ class LossScaler(object):
        else:
            self.unscale_grads_python(model_grads, master_grads, scale)

-    # Break into multiple param groups so unscale() can be called more than once before updating.
-    def update_scale(self):
+        # If the fused kernel is available, we only need one D2H memcopy and sync.
        if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
            self._has_overflow = self._overflow_buf.item()
+        # torch.cuda.nvtx.range_pop()
+
+    # Separate so unscale() can be called more than once before updating.
+    def update_scale(self):
        if self._has_overflow and self.dynamic:
            should_skip = True
            self._loss_scale /= 2.
......
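The scaler hunks above turn clear_overflow_state() / unscale() / update_scale() into a small protocol: unscale() may run several times per iteration (fp16 and fp32 groups, multiple param groups), the overflow flag stays sticky in between, and the scale is adjusted exactly once. The class below is a condensed, pure-Python illustration of that protocol under a hypothetical name; it omits the fused multi-tensor kernel and is not apex's LossScaler.

import torch

class MiniLossScaler:
    """Illustration of the clear -> unscale (xN) -> update protocol; not apex's LossScaler."""
    def __init__(self, init_scale=2.**16, dynamic=True):
        self._loss_scale = init_scale
        self.dynamic = dynamic
        self._has_overflow = False

    def loss_scale(self):
        return self._loss_scale

    def clear_overflow_state(self):
        # Reset the sticky overflow flag once per iteration, before any unscale() calls.
        self._has_overflow = False

    def unscale(self, model_grads, master_grads, scale):
        # Safe to call several times per iteration; the overflow flag stays set
        # until clear_overflow_state() runs again.
        if self._has_overflow:
            return
        for model, master in zip(model_grads, master_grads):
            if model is None:
                continue
            if self.dynamic:
                cpu_sum = float(model.float().sum())
                if cpu_sum != cpu_sum or cpu_sum in (float('inf'), -float('inf')):
                    self._has_overflow = True
                    return
            if master is not model:
                master.copy_(model)
            if scale != 1.0:
                master.mul_(1. / scale)

    def update_scale(self):
        # Called once per iteration after all unscale() calls; halve the scale
        # on overflow and tell the caller to skip optimizer.step().
        if self._has_overflow and self.dynamic:
            self._loss_scale /= 2.
            return True
        return False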
@@ -393,7 +393,9 @@ class FP16_Optimizer(object):
        if closure is not None:
            retval = self._step_with_closure(closure)
        else:
+            # torch.cuda.nvtx.range_push("pytorch optimizer step")
            retval = self.optimizer.step()
+            # torch.cuda.nvtx.range_pop()

        self._master_params_to_model_params()
@@ -502,6 +504,7 @@ class FP16_Optimizer(object):
            self.update_master_grads()

    def update_master_grads(self):
+        # torch.cuda.nvtx.range_push("update_master_grads")
        """
        Copy the ``.grad`` attribute from stored references to fp16 parameters to
        the ``.grad`` attribute of the fp32 master parameters that are directly
@@ -514,6 +517,7 @@ class FP16_Optimizer(object):
        # self._model_grads_to_master_grads()
        # self._downscale_master()
        # Use the one-shot multi-tensor apply kernel
+        self.loss_scaler.clear_overflow_state()
        if len(self.all_fp16_params) > 0:
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp16_params])
@@ -534,6 +538,7 @@ class FP16_Optimizer(object):
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
            # quit()
        self.overflow = self.loss_scaler.update_scale()
+        # torch.cuda.nvtx.range_pop()

    def inspect_master_grad_data(self):
......
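The fp16_utils.FP16_Optimizer hunks are the fix the commit title refers to: update_master_grads() now calls loss_scaler.clear_overflow_state() before unscaling and records the result of update_scale() in self.overflow. A rough usage sketch follows, assuming the documented backward(loss, update_master_grads=...) signature; the model, data, and hyperparameters are placeholders.

import torch
from apex.fp16_utils import FP16_Optimizer

model = torch.nn.Linear(8, 2).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
# The same wrapping that _initialize() applies above when master weights are requested.
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

inputs = torch.randn(4, 8).cuda().half()
targets = torch.randint(0, 2, (4,)).cuda()

loss = torch.nn.functional.cross_entropy(model(inputs).float(), targets)
optimizer.zero_grad()
optimizer.backward(loss, update_master_grads=False)   # scale the loss and backprop only
optimizer.update_master_grads()   # clear_overflow_state() -> unscale -> update_scale, per the hunk above
optimizer.step()                  # skips the parameter update internally if an overflow was detected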
@@ -95,6 +95,7 @@ if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.local_rank)
+    torch.set_printoptions(precision=10)

# Initialize Amp
amp_handle = amp.init(enabled=args.fp16)
@@ -337,7 +338,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {3:.3f} ({4:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader),
......
@@ -99,6 +99,7 @@ if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.local_rank)
+    torch.set_printoptions(precision=10)

def main():
    global best_prec1, args
@@ -344,7 +345,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {3:.3f} ({4:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                  'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader),
......