Unverified commit fffc8757 authored by Y. Xiong, committed by GitHub

[Feature]: Support auto_fp16 using torch.cuda.amp when PyTorch >= 1.6.0 (#951)

* add torch.cuda.amp to fp16_utils and optimizers

* use with context manager for autocast

* add doc to explain the behavior differences between real amp and ours

* fix docstring
parent 0fc19b46
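
In short, on PyTorch >= 1.6 the decorated forward now runs inside torch.cuda.amp.autocast, while older versions keep mmcv's manual fp16 casting. A condensed, illustrative sketch of that dispatch (the helper run_forward is hypothetical, not part of mmcv):

import torch

try:
    # available since PyTorch 1.6.0
    from torch.cuda.amp import autocast
except ImportError:
    autocast = None


def run_forward(module, *inputs):
    """Hypothetical helper mirroring the dispatch added in this commit."""
    if autocast is not None:
        # PyTorch >= 1.6: let amp pick per-op precision
        with autocast(enabled=True):
            return module(*inputs)
    # older PyTorch: cast fp32 tensors to fp16 by hand, assuming the module
    # itself was already converted with wrap_fp16_model
    inputs = tuple(
        x.half() if torch.is_tensor(x) and x.dtype == torch.float32 else x
        for x in inputs)
    return module(*inputs)
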
@@ -7,8 +7,18 @@ import numpy as np
import torch
import torch.nn as nn

from mmcv.utils import TORCH_VERSION
from .dist_utils import allreduce_grads as _allreduce_grads
try:
# If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
# Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
    # manually, so the behavior may not be consistent with real amp.
from torch.cuda.amp import autocast
except ImportError:
pass
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensor in inputs from src_type to dst_type.
@@ -45,7 +55,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
    This decorator is useful when you write custom modules and want to support
    mixed precision training. If input arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    backend; otherwise, the original mmcv implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
@@ -82,6 +93,7 @@ def auto_fp16(apply_to=None, out_fp32=False):
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
@@ -107,7 +119,11 @@ def auto_fp16(apply_to=None, out_fp32=False):
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
                with autocast(enabled=True):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
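
For reference, a hedged usage sketch of the decorator above. auto_fp16 only takes effect once the module's fp16_enabled flag is set (e.g. by wrap_fp16_model); the class below is illustrative, not mmcv code:

import torch.nn as nn
from mmcv.runner import auto_fp16


class ToyHead(nn.Module):
    """Illustrative module, not part of mmcv."""

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False  # flipped to True by wrap_fp16_model
        self.fc = nn.Linear(8, 2)

    @auto_fp16(apply_to=('feats', ))
    def forward(self, feats):
        # on PyTorch >= 1.6 this body runs under autocast; on older
        # versions `feats` arrives already cast to fp16
        return self.fc(feats)
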
@@ -125,7 +141,9 @@ def force_fp32(apply_to=None, out_fp16=False):
    mixed precision training. If there are some inputs that must be processed
    in fp32 mode, then this decorator can handle it. If input arguments are
    fp16 tensors, they will be converted to fp32 automatically. Arguments other
    than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
    torch.cuda.amp is used as the backend; otherwise, the original mmcv
    implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
@@ -186,7 +204,11 @@ def force_fp32(apply_to=None, out_fp16=False):
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
                with autocast(enabled=False):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp16 if necessary
            if out_fp16:
                output = cast_tensor_type(output, torch.float, torch.half)
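
And a companion sketch for force_fp32: numerically sensitive methods (losses, metrics, etc.) can opt back into fp32 even inside an auto_fp16 forward; the class is illustrative, not mmcv code:

import torch.nn as nn
from mmcv.runner import force_fp32


class ToyLoss(nn.Module):
    """Illustrative module, not part of mmcv."""

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False  # flipped to True by wrap_fp16_model

    @force_fp32(apply_to=('pred', 'target'))
    def forward(self, pred, target):
        # on PyTorch >= 1.6 autocast is disabled inside this method; on
        # older versions fp16 inputs are cast back to fp32 first
        return ((pred - target) ** 2).mean()
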
@@ -207,16 +229,25 @@ def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
def wrap_fp16_model(model):
    """Wrap the FP32 model to FP16.

    If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend;
    otherwise, the original mmcv implementation will be adopted.

    For PyTorch >= 1.6, this function will
    1. Set the fp16 flag inside the model to True.

    Otherwise:
    1. Convert FP32 model to FP16.
    2. Keep some necessary layers in FP32, e.g., normalization layers.
    3. Set the `fp16_enabled` flag inside the model to True.

    Args:
        model (nn.Module): Model in FP32.
    """
    if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.6.0':
        # convert model to fp16
        model.half()
        # patch the normalization layers to make them work in fp32 mode
        patch_norm_fp32(model)
    # set `fp16_enabled` flag
    for m in model.modules():
        if hasattr(m, 'fp16_enabled'):
...
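
The effect of the wrap_fp16_model change can be checked with a small sketch (the toy model is illustrative, not mmcv code):

import torch
import torch.nn as nn
from mmcv.runner import wrap_fp16_model
from mmcv.utils import TORCH_VERSION

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
wrap_fp16_model(model)

if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
    # parameters stay fp32; autocast handles casting at run time
    assert next(model.parameters()).dtype == torch.float32
else:
    # conv weights are halved, while patched norm layers stay fp32
    assert next(model[0].parameters()).dtype == torch.float16
    assert next(model[1].parameters()).dtype == torch.float32
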
@@ -5,10 +5,18 @@ from itertools import chain
from torch.nn.utils import clip_grad

from mmcv.utils import TORCH_VERSION
from ..dist_utils import allreduce_grads
from ..fp16_utils import LossScaler, wrap_fp16_model
from .hook import HOOKS, Hook
try:
# If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
from torch.cuda.amp import GradScaler
except ImportError:
pass
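
For context, the bare torch.cuda.amp recipe that the new hook below wraps looks roughly like this (the toy model, optimizer, and data are placeholders, not mmcv code):

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler()

for _ in range(10):
    inputs = torch.randn(4, 8, device='cuda')
    targets = torch.randn(4, 2, device='cuda')
    optimizer.zero_grad()
    with autocast():
        loss = ((model(inputs) - targets) ** 2).mean()
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.unscale_(optimizer)      # so grad clipping sees unscaled grads
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=35.0)
    scaler.step(optimizer)          # skips the step if grads overflowed
    scaler.update()                 # adjust the loss scale for next time
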
@HOOKS.register_module()
class OptimizerHook(Hook):

@@ -34,128 +42,237 @@ class OptimizerHook(Hook):
        runner.optimizer.step()
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':

    @HOOKS.register_module()
    class Fp16OptimizerHook(OptimizerHook):
        """FP16 optimizer hook (using PyTorch's implementation).

        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend
        to take care of the optimization procedure.

        Args:
            loss_scale (float | str | dict): Scale factor configuration.
                If loss_scale is a float, static loss scaling will be used with
                the specified scale. If loss_scale is a string, it must be
                'dynamic', then dynamic loss scaling will be used.
                It can also be a dict containing arguments of GradScaler.
                Defaults to 512. For PyTorch >= 1.6, mmcv uses the official
                implementation of GradScaler. If you use a dict version of
                loss_scale to create GradScaler, please refer to:
                https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
                for the parameters.

        Examples:
            >>> loss_scale = dict(
            ...     init_scale=65536.0,
            ...     growth_factor=2.0,
            ...     backoff_factor=0.5,
            ...     growth_interval=2000
            ... )
            >>> optimizer = Fp16OptimizerHook(loss_scale=loss_scale)
        """

        def __init__(self,
                     grad_clip=None,
                     coalesce=True,
                     bucket_size_mb=-1,
                     loss_scale=512.,
                     distributed=True):
            self.grad_clip = grad_clip
            self.coalesce = coalesce
            self.bucket_size_mb = bucket_size_mb
            self.distributed = distributed
            self._scale_update_param = None
            if loss_scale == 'dynamic':
                self.loss_scaler = GradScaler()
            elif isinstance(loss_scale, float):
                self._scale_update_param = loss_scale
                self.loss_scaler = GradScaler(init_scale=loss_scale)
            elif isinstance(loss_scale, dict):
                self.loss_scaler = GradScaler(**loss_scale)
            else:
                raise ValueError('loss_scale must be of type float, dict, or '
                                 f'"dynamic", got {loss_scale}')

        def before_run(self, runner):
            """Preparing steps before Mixed Precision Training."""
            # wrap the model to fp16
            wrap_fp16_model(runner.model)

        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
            """Copy gradients from fp16 model to fp32 weight copy."""
            for fp32_param, fp16_param in zip(fp32_weights,
                                              fp16_net.parameters()):
                if fp16_param.grad is not None:
                    if fp32_param.grad is None:
                        fp32_param.grad = fp32_param.data.new(
                            fp32_param.size())
                    fp32_param.grad.copy_(fp16_param.grad)

        def copy_params_to_fp16(self, fp16_net, fp32_weights):
            """Copy updated params from fp32 weight copy to fp16 model."""
            for fp16_param, fp32_param in zip(fp16_net.parameters(),
                                              fp32_weights):
                fp16_param.data.copy_(fp32_param.data)

        def after_train_iter(self, runner):
            """Backward optimization steps for Mixed Precision Training. For
            dynamic loss scaling, please refer to
            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.

            1. Scale the loss by a scale factor.
            2. Backward the loss to obtain the gradients.
            3. Unscale the optimizer's gradient tensors.
            4. Call optimizer.step() and update the scale factor.
            """
            # clear grads of last iteration
            runner.model.zero_grad()
            runner.optimizer.zero_grad()
            self.loss_scaler.scale(runner.outputs['loss']).backward()
            self.loss_scaler.unscale_(runner.optimizer)
            # grad clip
            if self.grad_clip is not None:
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])
            # step the optimizer and update the loss scale
            self.loss_scaler.step(runner.optimizer)
            self.loss_scaler.update(self._scale_update_param)

else:

    @HOOKS.register_module()
    class Fp16OptimizerHook(OptimizerHook):
        """FP16 optimizer hook (mmcv's implementation).

        The steps of the fp16 optimizer are as follows.

        1. Scale the loss value.
        2. BP in the fp16 model.
        3. Copy gradients from fp16 model to fp32 weights.
        4. Update fp32 weights.
        5. Copy updated parameters from fp32 weights to fp16 model.

        Refer to https://arxiv.org/abs/1710.03740 for more details.

        Args:
            loss_scale (float | str | dict): Scale factor configuration.
                If loss_scale is a float, static loss scaling will be used with
                the specified scale. If loss_scale is a string, it must be
                'dynamic', then dynamic loss scaling will be used.
                It can also be a dict containing arguments of LossScaler.
                Defaults to 512.
        """

        def __init__(self,
                     grad_clip=None,
                     coalesce=True,
                     bucket_size_mb=-1,
                     loss_scale=512.,
                     distributed=True):
            self.grad_clip = grad_clip
            self.coalesce = coalesce
            self.bucket_size_mb = bucket_size_mb
            self.distributed = distributed
            if loss_scale == 'dynamic':
                self.loss_scaler = LossScaler(mode='dynamic')
            elif isinstance(loss_scale, float):
                self.loss_scaler = LossScaler(
                    init_scale=loss_scale, mode='static')
            elif isinstance(loss_scale, dict):
                self.loss_scaler = LossScaler(**loss_scale)
            else:
                raise ValueError('loss_scale must be of type float, dict, or '
                                 f'"dynamic", got {loss_scale}')

        def before_run(self, runner):
            """Preparing steps before Mixed Precision Training.

            1. Make a master copy of fp32 weights for optimization.
            2. Convert the main model from fp32 to fp16.
            """
            # keep a copy of fp32 weights
            old_groups = runner.optimizer.param_groups
            runner.optimizer.param_groups = copy.deepcopy(
                runner.optimizer.param_groups)
            state = defaultdict(dict)
            p_map = {
                old_p: p
                for old_p, p in zip(
                    chain(*(g['params'] for g in old_groups)),
                    chain(*(g['params']
                            for g in runner.optimizer.param_groups)))
            }
            for k, v in runner.optimizer.state.items():
                state[p_map[k]] = v
            runner.optimizer.state = state
            # convert model to fp16
            wrap_fp16_model(runner.model)

        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
            """Copy gradients from fp16 model to fp32 weight copy."""
            for fp32_param, fp16_param in zip(fp32_weights,
                                              fp16_net.parameters()):
                if fp16_param.grad is not None:
                    if fp32_param.grad is None:
                        fp32_param.grad = fp32_param.data.new(
                            fp32_param.size())
                    fp32_param.grad.copy_(fp16_param.grad)

        def copy_params_to_fp16(self, fp16_net, fp32_weights):
            """Copy updated params from fp32 weight copy to fp16 model."""
            for fp16_param, fp32_param in zip(fp16_net.parameters(),
                                              fp32_weights):
                fp16_param.data.copy_(fp32_param.data)

        def after_train_iter(self, runner):
            """Backward optimization steps for Mixed Precision Training. For
            dynamic loss scaling, please refer to the `LossScaler` class in
            `fp16_utils.py`.

            1. Scale the loss by a scale factor.
            2. Backward the loss to obtain the gradients (fp16).
            3. Copy gradients from the model to the fp32 weight copy.
            4. Scale the gradients back and update the fp32 weight copy.
            5. Copy back the params from fp32 weight copy to the fp16 model.
            """
            # clear grads of last iteration
            runner.model.zero_grad()
            runner.optimizer.zero_grad()
            # scale the loss value
            scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
            scaled_loss.backward()
            # copy fp16 grads in the model to fp32 params in the optimizer
            fp32_weights = []
            for param_group in runner.optimizer.param_groups:
                fp32_weights += param_group['params']
            self.copy_grads_to_fp32(runner.model, fp32_weights)
            # allreduce grads
            if self.distributed:
                allreduce_grads(fp32_weights, self.coalesce,
                                self.bucket_size_mb)
            has_overflow = self.loss_scaler.has_overflow(fp32_weights)
            # if has overflow, skip this iteration
            if not has_overflow:
                # scale the gradients back
                for param in fp32_weights:
                    if param.grad is not None:
                        param.grad.div_(self.loss_scaler.loss_scale)
                if self.grad_clip is not None:
                    grad_norm = self.clip_grads(fp32_weights)
                    if grad_norm is not None:
                        # Add grad norm to the logger
                        runner.log_buffer.update(
                            {'grad_norm': float(grad_norm)},
                            runner.outputs['num_samples'])
                # update fp32 params
                runner.optimizer.step()
                # copy fp32 params to the fp16 model
                self.copy_params_to_fp16(runner.model, fp32_weights)
            self.loss_scaler.update_scale(has_overflow)
            if has_overflow:
                runner.logger.warning('Check overflow, downscale loss scale '
                                      f'to {self.loss_scaler.cur_scale}')