Unverified Commit fffc8757 authored by Y. Xiong's avatar Y. Xiong Committed by GitHub
Browse files

[Feature]: Support auto_fp16 using torch.cuda.amp when PyTorch >= 1.6.0 (#951)

* add torch.cuda.amp to fp16_utils and optimizers

* use with context manager for autocast

* add doc to explain the behavior differences between real amp and ours

* fix docstring
parent 0fc19b46
......@@ -7,8 +7,18 @@ import numpy as np
import torch
import torch.nn as nn
from mmcv.utils import TORCH_VERSION
from .dist_utils import allreduce_grads as _allreduce_grads
try:
# If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
# Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
# manually, so the behavior may not be consistant with real amp.
from torch.cuda.amp import autocast
except ImportError:
pass
def cast_tensor_type(inputs, src_type, dst_type):
"""Recursively convert Tensor in inputs from src_type to dst_type.
......@@ -45,7 +55,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
This decorator is useful when you write custom modules and want to support
mixed precision training. If inputs arguments are fp32 tensors, they will
be converted to fp16 automatically. Arguments other than fp32 tensors are
ignored.
ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
backend, otherwise, original mmcv implementation will be adopted.
Args:
apply_to (Iterable, optional): The argument names to be converted.
......@@ -82,6 +93,7 @@ def auto_fp16(apply_to=None, out_fp32=False):
'method of nn.Module')
if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
return old_func(*args, **kwargs)
# get the arg spec of the decorated method
args_info = getfullargspec(old_func)
# get the argument names to be casted
......@@ -107,6 +119,10 @@ def auto_fp16(apply_to=None, out_fp32=False):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
with autocast(enabled=True):
output = old_func(*new_args, **new_kwargs)
else:
output = old_func(*new_args, **new_kwargs)
# cast the results back to fp32 if necessary
if out_fp32:
......@@ -125,7 +141,9 @@ def force_fp32(apply_to=None, out_fp16=False):
mixed precision training. If there are some inputs that must be processed
in fp32 mode, then this decorator can handle it. If inputs arguments are
fp16 tensors, they will be converted to fp32 automatically. Arguments other
than fp16 tensors are ignored.
than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
torch.cuda.amp is used as the backend, otherwise, original mmcv
implementation will be adopted.
Args:
apply_to (Iterable, optional): The argument names to be converted.
......@@ -186,6 +204,10 @@ def force_fp32(apply_to=None, out_fp16=False):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
with autocast(enabled=False):
output = old_func(*new_args, **new_kwargs)
else:
output = old_func(*new_args, **new_kwargs)
# cast the results back to fp32 if necessary
if out_fp16:
......@@ -207,12 +229,21 @@ def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
def wrap_fp16_model(model):
"""Wrap the FP32 model to FP16.
If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
backend, otherwise, original mmcv implementation will be adopted.
For PyTorch >= 1.6, this function will
1. Set fp16 flag inside the model to True.
Otherwise:
1. Convert FP32 model to FP16.
2. Remain some necessary layers to be FP32, e.g., normalization layers.
3. Set `fp16_enabled` flag inside the model to True.
Args:
model (nn.Module): Model in FP32.
"""
if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.6.0':
# convert model to fp16
model.half()
# patch the normalization layers to make it work in fp32 mode
......
......@@ -5,10 +5,18 @@ from itertools import chain
from torch.nn.utils import clip_grad
from mmcv.utils import TORCH_VERSION
from ..dist_utils import allreduce_grads
from ..fp16_utils import LossScaler, wrap_fp16_model
from .hook import HOOKS, Hook
try:
# If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
from torch.cuda.amp import GradScaler
except ImportError:
pass
@HOOKS.register_module()
class OptimizerHook(Hook):
......@@ -34,9 +42,111 @@ class OptimizerHook(Hook):
runner.optimizer.step()
@HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook):
"""FP16 optimizer hook.
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
@HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook):
"""FP16 optimizer hook (using PyTorch's implementation).
If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
to take care of the optimization procedure.
Args:
loss_scale (float | str | dict): Scale factor configuration.
If loss_scale is a float, static loss scaling will be used with
the specified scale. If loss_scale is a string, it must be
'dynamic', then dynamic loss scaling will be used.
It can also be a dict containing arguments of GradScalar.
Defaults to 512. For Pytorch >= 1.6, mmcv uses official
implementation of GradScaler. If you use a dict version of
loss_scale to create GradScaler, plese refer to:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
for the parameters.
Examples:
>>> loss_scale = dict(
... init_scale=65536.0,
... growth_factor=2.0,
... backoff_factor=0.5,
... growth_interval=2000
... )
>>> optimizer = Fp16OptimizerHook(loss_scale=loss_scale)
"""
def __init__(self,
grad_clip=None,
coalesce=True,
bucket_size_mb=-1,
loss_scale=512.,
distributed=True):
self.grad_clip = grad_clip
self.coalesce = coalesce
self.bucket_size_mb = bucket_size_mb
self.distributed = distributed
self._scale_update_param = None
if loss_scale == 'dynamic':
self.loss_scaler = GradScaler()
elif isinstance(loss_scale, float):
self._scale_update_param = loss_scale
self.loss_scaler = GradScaler(init_scale=loss_scale)
elif isinstance(loss_scale, dict):
self.loss_scaler = GradScaler(**loss_scale)
else:
raise ValueError('loss_scale must be of type float, dict, or '
f'"dynamic", got {loss_scale}')
def before_run(self, runner):
"""Preparing steps before Mixed Precision Training."""
# wrap model mode to fp16
wrap_fp16_model(runner.model)
def copy_grads_to_fp32(self, fp16_net, fp32_weights):
"""Copy gradients from fp16 model to fp32 weight copy."""
for fp32_param, fp16_param in zip(fp32_weights,
fp16_net.parameters()):
if fp16_param.grad is not None:
if fp32_param.grad is None:
fp32_param.grad = fp32_param.data.new(
fp32_param.size())
fp32_param.grad.copy_(fp16_param.grad)
def copy_params_to_fp16(self, fp16_net, fp32_weights):
"""Copy updated params from fp32 weight copy to fp16 model."""
for fp16_param, fp32_param in zip(fp16_net.parameters(),
fp32_weights):
fp16_param.data.copy_(fp32_param.data)
def after_train_iter(self, runner):
"""Backward optimization steps for Mixed Precision Training. For
dynamic loss scaling, please refer to
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
1. Scale the loss by a scale factor.
2. Backward the loss to obtain the gradients.
3. Unscale the optimizer’s gradient tensors.
4. Call optimizer.step() and update scale factor.
"""
# clear grads of last iteration
runner.model.zero_grad()
runner.optimizer.zero_grad()
self.loss_scaler.scale(runner.outputs['loss']).backward()
self.loss_scaler.unscale_(runner.optimizer)
# grad clip
if self.grad_clip is not None:
grad_norm = self.clip_grads(runner.model.parameters())
if grad_norm is not None:
# Add grad norm to the logger
runner.log_buffer.update({'grad_norm': float(grad_norm)},
runner.outputs['num_samples'])
# backward and update scaler
self.loss_scaler.step(runner.optimizer)
self.loss_scaler.update(self._scale_update_param)
else:
@HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook):
"""FP16 optimizer hook (mmcv's implementation).
The steps of fp16 optimizer is as follows.
1. Scale the loss value.
......@@ -48,7 +158,7 @@ class Fp16OptimizerHook(OptimizerHook):
Refer to https://arxiv.org/abs/1710.03740 for more details.
Args:
loss_scale (float | str | dict): Scale factor multiplied with loss.
loss_scale (float | str | dict): Scale factor configuration.
If loss_scale is a float, static loss scaling will be used with
the specified scale. If loss_scale is a string, it must be
'dynamic', then dynamic loss scaling will be used.
......@@ -69,7 +179,8 @@ class Fp16OptimizerHook(OptimizerHook):
if loss_scale == 'dynamic':
self.loss_scaler = LossScaler(mode='dynamic')
elif isinstance(loss_scale, float):
self.loss_scaler = LossScaler(init_scale=loss_scale, mode='static')
self.loss_scaler = LossScaler(
init_scale=loss_scale, mode='static')
elif isinstance(loss_scale, dict):
self.loss_scaler = LossScaler(**loss_scale)
else:
......@@ -91,7 +202,8 @@ class Fp16OptimizerHook(OptimizerHook):
old_p: p
for old_p, p in zip(
chain(*(g['params'] for g in old_groups)),
chain(*(g['params'] for g in runner.optimizer.param_groups)))
chain(*(g['params']
for g in runner.optimizer.param_groups)))
}
for k, v in runner.optimizer.state.items():
state[p_map[k]] = v
......@@ -101,15 +213,18 @@ class Fp16OptimizerHook(OptimizerHook):
def copy_grads_to_fp32(self, fp16_net, fp32_weights):
"""Copy gradients from fp16 model to fp32 weight copy."""
for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()):
for fp32_param, fp16_param in zip(fp32_weights,
fp16_net.parameters()):
if fp16_param.grad is not None:
if fp32_param.grad is None:
fp32_param.grad = fp32_param.data.new(fp32_param.size())
fp32_param.grad = fp32_param.data.new(
fp32_param.size())
fp32_param.grad.copy_(fp16_param.grad)
def copy_params_to_fp16(self, fp16_net, fp32_weights):
"""Copy updated params from fp32 weight copy to fp16 model."""
for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights):
for fp16_param, fp32_param in zip(fp16_net.parameters(),
fp32_weights):
fp16_param.data.copy_(fp32_param.data)
def after_train_iter(self, runner):
......@@ -136,7 +251,8 @@ class Fp16OptimizerHook(OptimizerHook):
self.copy_grads_to_fp32(runner.model, fp32_weights)
# allreduce grads
if self.distributed:
allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb)
allreduce_grads(fp32_weights, self.coalesce,
self.bucket_size_mb)
has_overflow = self.loss_scaler.has_overflow(fp32_weights)
# if has overflow, skip this iteration
......@@ -149,7 +265,8 @@ class Fp16OptimizerHook(OptimizerHook):
grad_norm = self.clip_grads(fp32_weights)
if grad_norm is not None:
# Add grad norm to the logger
runner.log_buffer.update({'grad_norm': float(grad_norm)},
runner.log_buffer.update(
{'grad_norm': float(grad_norm)},
runner.outputs['num_samples'])
# update fp32 params
runner.optimizer.step()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment