Commit 843cdbe0 authored by Michael Carilli

Merging in master

parents 724672d7 28097c99
...@@ -10,9 +10,7 @@ users as quickly as possible.

# Contents

## 1. Amp: Automatic Mixed Precision

`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
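For reference, the "3 lines" are typically the `amp` import, the `amp.initialize` call, and the `scale_loss` context manager around `backward()`. A minimal sketch, assuming an existing `model`, `optimizer`, and `loss`:

```
from apex import amp

# Wrap the model and optimizer once, after they are constructed.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Scale the loss for the backward pass; optimizer.step() stays unchanged.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```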
...@@ -78,7 +76,7 @@ It's often convenient to use Apex in Docker containers. Compatible options incl

For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
```
...@@ -95,6 +93,5 @@ A Python-only build omits:

`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.

### Windows support
Windows support is experimental, and Linux is recommended. `pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
on your system. `pip install -v --no-cache-dir .` (without CUDA/C++ extensions) is more likely to work. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
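For completeness, the Python-only install mentioned above is the same `pip` command without the extension flags:

```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir .
```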
...@@ -2,4 +2,4 @@ from .amp import init, half_function, float_function, promote_function,\
    register_half_function, register_float_function, register_promote_function
from .handle import scale_loss, disable_casts
from .frontend import initialize
from ._amp_state import master_params, _amp_state
...@@ -2,13 +2,29 @@

# I'm a C++ guy, not a python guy.  I decided this approach because it seemed most C++-like.
# But apparently it's ok:
# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
import os
import torch

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

if TORCH_MAJOR == 0:
    import collections.abc as container_abcs
else:
    from torch._six import container_abcs


class AmpState(object):
    def __init__(self):
        self.hard_override=False
        self.allow_incoming_model_not_fp32 = False
        self.verbosity=1


# Attribute stash.  Could also just stash things as global module attributes.
_amp_state = AmpState()


def warn_or_err(msg):
    if _amp_state.hard_override:
        print("Warning:  " + msg)
...@@ -18,11 +34,30 @@ def warn_or_err(msg):
        # + " If you're sure you know what you're doing, supply " +
        # "hard_override=True to amp.initialize.")
distributed = False
if 'WORLD_SIZE' in os.environ:
    distributed = int(os.environ['WORLD_SIZE']) > 1


def maybe_print(msg, rank0=False):
    if _amp_state.verbosity > 0:
        if rank0:
            if distributed:
                if torch.distributed.get_rank() == 0:
                    print(msg)
            else:
                print(msg)
        else:
            print(msg)


# def iter_params(param_groups):
#     for group in param_groups:
#         for p in group['params']:
#             yield p


def master_params(optimizer):
    """
    Generator expression that iterates over the params owned by ``optimizer``.
......
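As an illustration of how `master_params` is typically consumed (a sketch, not part of this diff; `loss` and `optimizer` are assumed to exist), gradient clipping should operate on the FP32 params owned by the optimizer rather than on the FP16 model params:

```
import torch
from apex import amp

with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
# Clip the unscaled master gradients, not the FP16 model gradients.
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
optimizer.step()
```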
import torch
from torch._six import string_classes
import functools
import numpy as np
import warnings
from ._amp_state import _amp_state, warn_or_err, container_abcs
from .handle import disable_casts
from .scaler import LossScaler
from ._process_optimizer import _process_optimizer
from apex.fp16_utils import convert_network
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
...@@ -15,11 +18,11 @@ def to_type(dtype, t):
    if isinstance(t, torch.Tensor):
        if not t.is_cuda:
            # This should not be a hard error, since it may be legitimate.
            warnings.warn("An input tensor was not cuda.")
        # GANs require this.
        # if t.requires_grad:
        #     warn_or_err("input data requires grad.  Since input data is not a model parameter,\n"
        #                 "its gradients will not be properly allreduced by DDP.")
        if t.is_floating_point():
            return t.to(dtype)
        return t
...@@ -34,6 +37,8 @@ def applier(value, fn):
        return fn(value)
    elif isinstance(value, string_classes):
        return value
    elif isinstance(value, np.ndarray):
        return value
    elif isinstance(value, container_abcs.Mapping):
        return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
    elif isinstance(value, container_abcs.Iterable):
...@@ -70,18 +75,32 @@ def check_models(models):

def check_params_fp32(models):
    for model in models:
        for name, param in model.named_parameters():
            if param.is_floating_point():
                if 'Half' in param.type():
                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                                "When using amp.initialize, you do not need to call .half() on your model\n"
                                "before passing it, no matter what optimization level you choose.".format(
                                name, param.type()))
                elif not param.is_cuda:
                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                                "When using amp.initialize, you need to provide a model with parameters\n"
                                "located on a CUDA device before passing it no matter what optimization level\n"
                                "you chose. Use model.to('cuda') to use the default device.".format(
                                name, param.type()))

        for name, buf in model.named_buffers():
            if buf.is_floating_point():
                if 'Half' in buf.type():
                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                                "When using amp.initialize, you do not need to call .half() on your model\n"
                                "before passing it, no matter what optimization level you choose.".format(
                                name, buf.type()))
                elif not buf.is_cuda:
                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                                "When using amp.initialize, you need to provide a model with buffers\n"
                                "located on a CUDA device before passing it no matter what optimization level\n"
                                "you chose. Use model.to('cuda') to use the default device.".format(
                                name, buf.type()))


def check_optimizers(optimizers):
...@@ -118,13 +137,15 @@ def wrap_fused_adam(optimizer, properties):
        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)


def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
    from apex.parallel import DistributedDataParallel as apex_DDP
    from .amp import init as amp_init

    optimizers_was_list = False
    if isinstance(optimizers, torch.optim.Optimizer):
        optimizers = [optimizers]
    elif optimizers is None:
        optimizers = []
    elif isinstance(optimizers, list):
        optimizers_was_list = True
    else:
...@@ -140,8 +161,9 @@ def _initialize(models, optimizers, properties):
    check_models(models)

    if not _amp_state.allow_incoming_model_not_fp32:
        check_params_fp32(models)

    check_optimizers(optimizers)

    # In the future, when FP16_Optimizer can be deprecated and master weights can
...@@ -155,41 +177,54 @@ def _initialize(models, optimizers, properties):
            for model in models:
                model.to(properties.cast_model_type)

        input_caster = functools.partial(to_type, properties.cast_model_type)
        if cast_model_outputs is not None:
            output_caster = functools.partial(to_type, cast_model_outputs)
        else:
            output_caster = functools.partial(to_type, torch.float32)

        for model in models:
            # Patch the forward method to cast incoming data to the correct type, and
            # outgoing data to float32, so "the user never needs to call .half()."
            # I like writing things explicitly more than decorators.
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*applier(args, input_caster),
                                     **applier(kwargs, input_caster))
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

        # State dict trick to recast any preexisting per-param state tensors
        for optimizer in optimizers:
            optimizer.load_state_dict(optimizer.state_dict())
    elif cast_model_outputs is not None:
        output_caster = functools.partial(to_type, cast_model_outputs)

        for model in models:
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*args, **kwargs)
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

    for i, optimizer in enumerate(optimizers):
        # Still need to special case this for the first pass
        if isinstance(optimizer, FusedAdam):
            optimizers[i] = wrap_fused_adam(optimizer, properties)
        else:
            optimizers[i] = _process_optimizer(optimizer, properties)

    _amp_state.loss_scalers = []
    for _ in range(num_losses):
        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale))

    if properties.patch_torch_functions:
        # handle is unused here. It's accessible later through a global value anyway.
        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
        for optimizer in optimizers:
            # Disable Amp casting for the optimizer step, because it should only be
            # applied to FP32 master params anyway.
...@@ -209,6 +244,12 @@ def _initialize(models, optimizers, properties):
            return models[0], optimizers
    else:
        if models_was_list:
            if len(optimizers) == 0:
                return models
            else:
                return models, optimizers[0]
        else:
            if len(optimizers) == 0:
                return models[0]
            else:
                return models[0], optimizers[0]
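The forward-patching closure above is what lets O2 cast inputs and outputs without the user touching their training loop. Below is a standalone sketch of the same pattern, independent of Amp; the helper name `cast_module_io` is illustrative only:

```
import torch

def cast_module_io(module, dtype=torch.float16):
    # Cast floating-point tensor inputs to `dtype`, and cast the output back to FP32,
    # mirroring the input_caster/output_caster pair used in _initialize above.
    old_fwd = module.forward
    def to(t, d):
        return t.to(d) if torch.is_tensor(t) and t.is_floating_point() else t
    def new_fwd(*args, **kwargs):
        out = old_fwd(*[to(a, dtype) for a in args],
                      **{k: to(v, dtype) for k, v in kwargs.items()})
        return to(out, torch.float32) if torch.is_tensor(out) else out
    module.forward = new_fwd
    return module
```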
import types
from ..fp16_utils import master_params_to_model_params
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import maybe_print
import torch


class AmpOptimizerState(object):
    def __init__(self):
        pass


def lazy_init_with_master_weights(self):
    stash = self._amp_stash
    stash.fp16_groups = []
    stash.fp32_from_fp16_groups = []
    stash.fp32_from_fp32_groups = []
    for i, param_group in enumerate(self.param_groups):
        # maybe_print("FP16_Optimizer processing param group {}:".format(i))
        fp16_params_this_group = []
        fp32_params_this_group = []
        fp32_from_fp16_params_this_group = []
        for i, param in enumerate(param_group['params']):
            if param.requires_grad:
                if param.type() == 'torch.cuda.HalfTensor':
                    # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                    #             .format(param.size()))
                    fp16_params_this_group.append(param)
                    master_param = param.detach().clone().float()
                    master_param.requires_grad = True
                    param_group['params'][i] = master_param
                    fp32_from_fp16_params_this_group.append(master_param)
                    # Reset existing state dict key to the new master param.
                    # We still need to recast per-param state tensors, if any, to FP32.
                    if param in self.state:
                        self.state[master_param] = self.state.pop(param)
                elif param.type() == 'torch.cuda.FloatTensor':
                    # maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                    #             .format(param.size()))
                    fp32_params_this_group.append(param)
                    param_group['params'][i] = param
                else:
                    raise TypeError("Optimizer's parameters must be either "
                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                    "Received {}".format(param.type()))

        stash.fp16_groups.append(fp16_params_this_group)
        stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
        stash.fp32_from_fp32_groups.append(fp32_params_this_group)

    stash.all_fp16_params = []
    for group in stash.fp16_groups:
        stash.all_fp16_params += group

    stash.all_fp32_from_fp16_params = []
    for group in stash.fp32_from_fp16_groups:
        stash.all_fp32_from_fp16_params += group

    stash.all_fp32_from_fp32_params = []
    for group in stash.fp32_from_fp32_groups:
        stash.all_fp32_from_fp32_params += group

    # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
    stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]

    for param in stash.all_fp32_from_fp16_params:
        param.grad = None

    for param in stash.all_fp32_from_fp32_params:
        param.grad = None

    # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
    self.load_state_dict(self.state_dict())


def prepare_backward_with_master_weights(self):
    stash = self._amp_stash

    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True

    for i, param in enumerate(stash.all_fp16_params):
        # Set up to leverage grad copy elision:
        param.grad = None

    # for i, param in enumerate(stash.all_fp32_from_fp16_params):
    #     stash.all_fp32_from_fp16_grad_stash[i] = param.grad

    for i, param in enumerate(stash.all_fp32_from_fp32_params):
        stash.all_fp32_from_fp32_grad_stash[i] = param.grad
        # Set up to leverage grad copy elision:
        param.grad = None


def post_backward_with_master_weights(self, scaler):
    stash = self._amp_stash

    # This is a lot of python overhead...
    fp16_grads_needing_unscale = []
    new_fp32_grads = []
    fp16_grads_needing_unscale_with_stash = []
    preexisting_fp32_grads = []
    for fp16_param, fp32_param in zip(stash.all_fp16_params,
                                      stash.all_fp32_from_fp16_params):
        if fp16_param.grad is None and fp32_param.grad is not None:
            continue
        elif fp16_param.grad is not None and fp32_param.grad is None:
            fp32_param.grad = torch.empty_like(fp32_param)
            fp16_grads_needing_unscale.append(fp16_param.grad)
            new_fp32_grads.append(fp32_param.grad)
        elif fp16_param.grad is not None and fp32_param.grad is not None:
            fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)
            preexisting_fp32_grads.append(fp32_param.grad)
        else: # fp16_param.grad is None and fp32_param.grad is None:
            continue

    if len(fp16_grads_needing_unscale) > 0:
        scaler.unscale(
            fp16_grads_needing_unscale,
            new_fp32_grads,
            scaler.loss_scale(),
            models_are_masters=False)

    if len(fp16_grads_needing_unscale_with_stash) > 0:
        scaler.unscale_with_stashed(
            fp16_grads_needing_unscale_with_stash,
            preexisting_fp32_grads,
            preexisting_fp32_grads)

    # fp32 params can be treated as they would be in the "no_master_weights" case.
    grads_needing_unscale = []
    grads_needing_unscale_with_stash = []
    stashed = []
    for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
                                   stash.all_fp32_from_fp32_grad_stash):
        if param.grad is None and stashed_grad is not None:
            param.grad = stashed_grad
        elif param.grad is not None and stashed_grad is None:
            grads_needing_unscale.append(param.grad)
        elif param.grad is not None and stashed_grad is not None:
            grads_needing_unscale_with_stash.append(param.grad)
            stashed.append(stashed_grad)
        else: # param.grad is None and stashed_grad is None:
            continue

    if len(grads_needing_unscale) > 0:
        scaler.unscale(
            grads_needing_unscale,
            grads_needing_unscale,
            scaler.loss_scale(),
            models_are_masters=True)

    if len(grads_needing_unscale_with_stash) > 0:
        scaler.unscale_with_stashed(
            grads_needing_unscale_with_stash,
            stashed,
            grads_needing_unscale_with_stash)

    # Clear the stash.
    for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
        stash.all_fp32_from_fp32_grad_stash[i] = None


def lazy_init_no_master_weights(self):
    stash = self._amp_stash
    stash.all_fp16_params = []
    stash.all_fp32_params = []
    for i, param_group in enumerate(self.param_groups):
        for i, param in enumerate(param_group['params']):
            if param.type() == 'torch.cuda.HalfTensor':
                stash.all_fp16_params.append(param)
            elif param.type() == 'torch.cuda.FloatTensor':
                stash.all_fp32_params.append(param)
            else:
                raise TypeError("Optimizer's parameters must be either "
                                "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                "Received {}".format(param.type()))

    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
    stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]


def prepare_backward_no_master_weights(self):
    stash = self._amp_stash

    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True

    for i, param in enumerate(stash.all_fp16_params):
        stash.all_fp16_grad_stash[i] = param.grad
        # Set up to leverage grad copy elision:
        param.grad = None

    for i, param in enumerate(stash.all_fp32_params):
        stash.all_fp32_grad_stash[i] = param.grad
        # Set up to leverage grad copy elision:
        param.grad = None


def post_backward_no_master_weights(self, scaler):
    stash = self._amp_stash

    split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                   (stash.all_fp32_params, stash.all_fp32_grad_stash))

    for params, stashed_grads in split_types:
        # This is a lot of python overhead...
        grads_needing_unscale = []
        grads_needing_unscale_with_stash = []
        stashed = []
        for param, stashed_grad in zip(params, stashed_grads):
            if param.grad is None and stashed_grad is not None:
                param.grad = stashed_grad
            elif param.grad is not None and stashed_grad is None:
                grads_needing_unscale.append(param.grad)
            elif param.grad is not None and stashed_grad is not None:
                grads_needing_unscale_with_stash.append(param.grad)
                stashed.append(stashed_grad)
            else: # param.grad is None and stashed_grad is None
                continue

        if len(grads_needing_unscale) > 0:
            scaler.unscale(
                grads_needing_unscale,
                grads_needing_unscale,
                scaler.loss_scale(),
                models_are_masters=True)

        if len(grads_needing_unscale_with_stash) > 0:
            scaler.unscale_with_stashed(
                grads_needing_unscale_with_stash,
                stashed,
                grads_needing_unscale_with_stash)

        # Clear the stash.
        for i in range(len(stashed_grads)):
            stashed_grads[i] = None


def _master_params_to_model_params(self):
    stash = self._amp_stash
    if multi_tensor_applier.available:
        if len(stash.all_fp16_params) > 0:
            multi_tensor_applier(
                stash.multi_tensor_scale,
                stash.dummy_overflow_buf,
                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
                1.0)
    else:
        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
            master_params_to_model_params(fp16_group, fp32_from_fp16_group)


def _process_optimizer(optimizer, properties):
    if hasattr(optimizer, "_amp_stash"):
        raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
    else:
        optimizer._amp_stash = AmpOptimizerState()

    optimizer._amp_stash.lazy_init_called = False
    optimizer._amp_stash.already_patched = False
    optimizer._amp_stash.params_have_scaled_gradients = False

    for name in ("_lazy_init_maybe_master_weights",
                 "_master_params_to_model_params",
                 "_prepare_amp_backward",
                 "_post_amp_backward"):
        if hasattr(optimizer, name):
            raise RuntimeError("Incoming optimizer already has {} defined.".format(name))

    # TODO: Centralize exposure and import error checking for the C backend.
    if multi_tensor_applier.available:
        import amp_C
        optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
        optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0])

    if properties.master_weights:
        optimizer._lazy_init_maybe_master_weights = types.MethodType(
            lazy_init_with_master_weights, optimizer)

        optimizer._master_params_to_model_params = types.MethodType(
            _master_params_to_model_params, optimizer)

        old_step = optimizer.step
        def new_step(self):
            retval = old_step()
            self._master_params_to_model_params()
            # Clear the master grads that wouldn't be zeroed by model.zero_grad()
            for param in self._amp_stash.all_fp32_from_fp16_params:
                param.grad = None
            return retval
        optimizer.step = types.MethodType(new_step, optimizer)

        old_zero_grad = optimizer.zero_grad
        def new_zero_grad(self):
            stash = self._amp_stash
            if not stash.lazy_init_called:
                self._lazy_init_maybe_master_weights()
                stash.lazy_init_called = True
            # Zero the model grads.
            for param in stash.all_fp16_params:
                if param.grad is not None:
                    param.grad.detach_()
                    param.grad.zero_()
            for param in stash.all_fp32_from_fp32_params:
                if param.grad is not None:
                    param.grad.detach_()
                    param.grad.zero_()
            # Clear the master grads that are independent of model grads
            for param in self._amp_stash.all_fp32_from_fp16_params:
                param.grad = None
        optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)

        optimizer._prepare_amp_backward = types.MethodType(
            prepare_backward_with_master_weights, optimizer)

        optimizer._post_amp_backward = types.MethodType(
            post_backward_with_master_weights, optimizer)
    else:
        optimizer._lazy_init_maybe_master_weights = types.MethodType(
            lazy_init_no_master_weights, optimizer)

        optimizer._prepare_amp_backward = types.MethodType(
            prepare_backward_no_master_weights, optimizer)

        optimizer._post_amp_backward = types.MethodType(
            post_backward_no_master_weights, optimizer)

    return optimizer
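The file above leans on `types.MethodType` to bind new methods onto a single optimizer instance without subclassing it. A minimal standalone illustration of that pattern (toy code, not Amp itself; the method name is hypothetical):

```
import types
import torch

def counting_zero_grad(self):
    # Hypothetical replacement bound to this one optimizer instance only.
    self._zero_grad_calls = getattr(self, "_zero_grad_calls", 0) + 1
    self._original_zero_grad()

param = torch.nn.Parameter(torch.zeros(2))
opt = torch.optim.SGD([param], lr=0.1)
opt._original_zero_grad = opt.zero_grad
opt.zero_grad = types.MethodType(counting_zero_grad, opt)
```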
import torch
from ._initialize import _initialize
from ._amp_state import _amp_state, warn_or_err, maybe_print


class Properties(object):
...@@ -53,12 +53,13 @@ class Properties(object):
        # print("setting {} {}".format(name, value))
        if name == "cast_model_type":
            if self.opt_level == "O1" and value is not None:
                if value is not False:
                    if value is not torch.float32:
                        warn_or_err("O1 inserts casts around Torch functions rather than "
                                    "model weights, so with O1, the model weights themselves "
                                    "should remain FP32. If you wish to cast the model to a "
                                    "different type, use opt_level='O2' or 'O3'. " +
                                    "cast_model_type was {}".format(value))
            self.options[name] = value
        elif name == "patch_torch_functions":
            if self.opt_level != "O1" and value:
...@@ -192,34 +193,45 @@ opt_levels = {"O3": O3(),

# allow user to directly pass Properties struct as well?
def initialize(
    models,
    optimizers=None,
    enabled=True,
    opt_level=None,
    cast_model_type=None,
    patch_torch_functions=None,
    keep_batchnorm_fp32=None,
    master_weights=None,
    loss_scale=None,
    cast_model_outputs=None,
    num_losses=1,
    verbosity=1,
    ):
    """
    Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
    chosen ``opt_level`` and overridden properties, if any.

    ``amp.initialize`` should be called **after** you have finished
    constructing your model(s) and
    optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.
    See `Distributed training`_ in the Imagenet example.

    Currently, ``amp.initialize`` should only be called **once**,
    although it can process an arbitrary number of
    models and optimizers (see the corresponding `Advanced Amp Usage topic`_).
    If you think your use case requires ``amp.initialize`` to be called more than once,
    `let us know`_.

    Any property keyword argument that is not ``None`` will be interpreted as a manual override.

    To prevent having to rewrite anything else in your script, name the returned models/optimizers
    to replace the passed models/optimizers, as in the code sample below.

    Args:
        models (torch.nn.Module or list of torch.nn.Modules):  Models to modify/cast.
        optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers):  Optimizers to modify/cast.
            REQUIRED for training, optional for inference.
        enabled (bool, optional, default=True):  If False, renders all Amp calls no-ops, so your script
            should run as if Amp were not present.
        opt_level (str, required):  Pure or mixed precision optimization level.  Accepted values are
            "O0", "O1", "O2", and "O3", explained in detail above.
        cast_model_type (``torch.dtype``, optional, default=None):  Optional property override, see
            above.
...@@ -227,15 +239,25 @@ def initialize(
        keep_batchnorm_fp32 (bool or str, optional, default=None):  Optional property override.  If
            passed as a string, must be the string "True" or "False".
        master_weights (bool, optional, default=None):  Optional property override.
        loss_scale (float or str, optional, default=None):  Optional property override.  If passed as a string,
            must be a string representing a number, e.g., "128.0", or the string "dynamic".
        cast_model_outputs (torch.dtype, optional, default=None):  Option to ensure that the outputs
            of your model(s) are always cast to a particular type regardless of ``opt_level``.
        num_losses (int, optional, default=1):  Option to tell Amp in advance how many losses/backward
            passes you plan to use.  When used in conjunction with the ``loss_id`` argument to
            ``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,
            which can improve stability.  See "Multiple models/optimizers/losses"
            under `Advanced Amp Usage`_ for examples.  If ``num_losses`` is left to 1, Amp will still
            support multiple losses/backward passes, but use a single global loss scale
            for all of them.
        verbosity (int, default=1):  Set to 0 to suppress Amp-related output.

    Returns:
        Model(s) and optimizer(s) modified according to the ``opt_level``.
        If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will
        also be a list.

    Permissible invocations::

        model, optim = amp.initialize(model, optim,...)
        model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
...@@ -265,9 +287,20 @@ def initialize(
    .. _`Imagenet example`:
        https://github.com/NVIDIA/apex/tree/master/examples/imagenet

    .. _`Advanced Amp Usage`:
        https://nvidia.github.io/apex/advanced.html

    .. _`Advanced Amp Usage topic`:
        https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses

    .. _`let us know`:
        https://github.com/NVIDIA/apex/issues
    """
    _amp_state.opt_properties = Properties()
    _amp_state.verbosity = verbosity

    if not enabled:
        return models, optimizers

    if opt_level not in opt_levels:
...@@ -275,16 +308,15 @@ def initialize(
            "Unexpected optimization level {}. ".format(opt_level) +
            "Options are 'O0', 'O1', 'O2', 'O3'.")
    else:
        _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
        maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
        maybe_print("Defaults for this optimization level are:", True)
        for k, v in _amp_state.opt_properties.options.items():
            maybe_print("{:22} : {}".format(k, v), True)

    maybe_print("Processing user overrides (additional kwargs that are not None)...", True)
    # I chose to have the keyword arguments listed directly in the argument list,
    # instead of **kwargs, so I can't use kwargs.items() here.
    if enabled is not None:
        _amp_state.opt_properties.enabled = enabled
    if opt_level is not None:
...@@ -300,11 +332,11 @@ def initialize(
    if loss_scale is not None:
        _amp_state.opt_properties.loss_scale = loss_scale

    maybe_print("After processing overrides, optimization options are:", True)
    for k, v in _amp_state.opt_properties.options.items():
        maybe_print("{:22} : {}".format(k, v), True)

    return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)


# TODO: is this necessary/useful?
......
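Putting the new keyword arguments together, a call using the extended signature might look like this (the `model` and `optimizer` objects are placeholders):

```
from apex import amp

model, optimizer = amp.initialize(
    model, optimizer,
    opt_level="O2",
    loss_scale="dynamic",   # property override; a fixed value such as 128.0 also works
    num_losses=2,           # pre-allocate one loss scaler per backward pass
    verbosity=1)
```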
import contextlib
import warnings
import torch

from . import utils
from .opt import OptimWrapper
from .scaler import LossScaler
from ._amp_state import _amp_state, master_params, maybe_print
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
...@@ -14,7 +13,8 @@ from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused

# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
@contextlib.contextmanager
def scale_loss(loss,
               optimizers,
               loss_id=0,
               model=None,
               delay_unscale=False):
    """
...@@ -38,41 +38,60 @@ def scale_loss(loss,
    unscaled.  The direct ``.grad`` attributes of any FP16
    model params will remain scaled after context manager exit.
    This subtlety affects gradient clipping.  See "Gradient clipping" under
    `Advanced Amp Usage`_ for best practices.

    Args:
        loss(Tensor):  Typically a scalar Tensor. The ``scaled_loss`` that the context
            manager yields is simply ``loss.float()*loss_scale``, so in principle
            ``loss`` could have more than one element, as long as you call
            ``backward()`` on ``scaled_loss`` appropriately within the context manager body.
        optimizers:  All optimizer(s) for which the current backward pass is creating gradients.
            Must be an optimizer or list of optimizers returned from an earlier call
            to ``amp.initialize``.  For example use with multiple optimizers, see
            "Multiple models/optimizers/losses" under `Advanced Amp Usage`_.
        loss_id(int, optional, default=0):  When used in conjunction with the ``num_losses`` argument
            to ``amp.initialize``, enables Amp to use a different loss scale per loss.  ``loss_id``
            must be an integer between 0 and ``num_losses`` that tells Amp which loss is
            being used for the current backward pass.  See "Multiple models/optimizers/losses"
            under `Advanced Amp Usage`_ for examples.  If ``loss_id`` is left unspecified, Amp
            will use the default global loss scaler for this backward pass.
        model(torch.nn.Module, optional, default=None):  Currently unused, reserved to enable future
            optimizations.
        delay_unscale(bool, optional, default=False):  ``delay_unscale`` is never necessary, and
            the default value of ``False`` is strongly recommended.
            If ``True``, Amp will not unscale the gradients or perform model->master
            gradient copies on context manager exit.
            ``delay_unscale=True`` is a minor ninja performance optimization and can result
            in weird gotchas (especially with multiple models/optimizers/losses),
            so only use it if you know what you're doing.
            "Gradient accumulation across iterations" under `Advanced Amp Usage`_
            illustrates a situation where this CAN (but does not need to) be used.

    .. warning::
        If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be
        called yet after context manager exit, and must wait for another, later backward context
        manager invocation with ``delay_unscale`` left to False.

    .. _`Advanced Amp Usage`:
        https://nvidia.github.io/apex/advanced.html
    """
    if not _amp_state.opt_properties.enabled:
        yield loss
        return

    if isinstance(optimizers, torch.optim.Optimizer):
        optimizers = [optimizers]

    # this is what happens when i have to support tools from different sources under the same API...
    # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
    if isinstance(optimizers, FP16_Optimizer_for_fused):
        loss_scale = optimizers.cur_scale
    else:
        loss_scaler = _amp_state.loss_scalers[loss_id]
        loss_scale = loss_scaler.loss_scale()

    if ((not _amp_state.opt_properties.master_weights)
        and (not loss_scaler.dynamic)
        and loss_scale == 1.0):
        yield loss.float()
        # Needing to drop the cache here as well is an ugly gotcha.
...@@ -82,35 +101,48 @@ def scale_loss(loss,
            _amp_state.handle._clear_cache()
        return

    if not delay_unscale:
        if isinstance(optimizers, list):
            for optimizer in optimizers:
                if not optimizer._amp_stash.params_have_scaled_gradients:
                    optimizer._prepare_amp_backward()

    yield (loss.float())*loss_scale

    if delay_unscale:
        for optimizer in optimizers:
            optimizer._amp_stash.params_have_scaled_gradients = True
    else:
        # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
        if not isinstance(optimizers, FP16_Optimizer_for_fused):
            loss_scaler.clear_overflow_state()
            for optimizer in optimizers:
                optimizer._post_amp_backward(loss_scaler)
                optimizer._amp_stash.params_have_scaled_gradients = False
            # For future fused optimizers that enable sync-free dynamic loss scaling,
            # should_skip will always be False.
            should_skip = loss_scaler.update_scale()
            if should_skip:
                for optimizer in optimizers:
                    if not optimizer._amp_stash.already_patched:
                        # Close on loss_scaler and loss_id as well, to be safe.  Probably not
                        # necessary because amp.scale_loss is already creating a temporary scope.
                        def patch_step(opt, loss_scaler, loss_id):
                            opt_step = opt.step
                            def skip_step():
                                maybe_print(("Gradient overflow.  Skipping step, loss scaler " +
                                             "{} reducing loss scale to {}").format(loss_id,
                                             loss_scaler.loss_scale()))
                                if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
                                    # Clear the master grads that wouldn't be zeroed by model.zero_grad()
                                    for param in opt._amp_stash.all_fp32_from_fp16_params:
                                        param.grad = None
                                opt.step = opt_step
                                opt._amp_stash.already_patched = False
                            return skip_step
                        optimizer.step = patch_step(optimizer, loss_scaler, loss_id)
                        optimizer._amp_stash.already_patched = True

    # Probably ok to skip this if not delay_unscale
    if _amp_state.opt_properties.patch_torch_functions:
        _amp_state.handle._clear_cache()
...@@ -149,6 +181,10 @@ class AmpHandle(object):
    @contextlib.contextmanager
    def scale_loss(self, loss, optimizer):
        raise RuntimeError("The old Amp API is no longer supported.  Please move to the new API, "
            "documented here:  https://nvidia.github.io/apex/amp.html.  Transition guide:  "
            "https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users")

        if not self.is_active():
            yield loss
            return
...@@ -171,8 +207,7 @@ class AmpHandle(object):
        if should_skip:
            optimizer_step = optimizer.step
            def skip_step():
                maybe_print('Gradient overflow, skipping update')
                optimizer.step = optimizer_step
            optimizer.step = skip_step
......
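For example, the `loss_id` and multi-optimizer arguments documented above combine roughly as follows (a sketch; `loss0`, `loss1`, and the optimizers are placeholders, and `amp.initialize` is assumed to have been called with `num_losses=2`):

```
from apex import amp

with amp.scale_loss(loss0, optimizer0, loss_id=0) as scaled_loss:
    scaled_loss.backward()

with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=1) as scaled_loss:
    scaled_loss.backward()

optimizer0.step()
optimizer1.step()
```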
...@@ -27,6 +27,10 @@ FP16_FUNCS = [
]

FP32_FUNCS = [
    # Interpolation/Upsampling
    'interpolate',

    # Pointwise
    'softplus',
    'softmin',
......
import torch
from .. import utils

MODULE = torch

FP16_FUNCS = [
    # Low level functions wrapped by torch.nn layers.
    # The wrapper layers contain the weights which are then passed in as a parameter
    # to these functions.
    'conv1d',
    'conv2d',
    'conv3d',
...@@ -12,6 +15,7 @@ FP16_FUNCS = [
    'conv_transpose2d',
    'conv_transpose3d',
    'conv_tbc',
    'prelu',

    # BLAS
    'addmm',
...@@ -20,10 +24,8 @@ FP16_FUNCS = [
    'matmul',
    'mm',
    'mv',
]

FP32_FUNCS = [
    # Pointwise
    'acos',
...@@ -54,15 +56,21 @@ FP32_FUNCS = [
    'sum',
    'var',

    # Misc
    'renorm'
]

# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
# check the CUDA version -- if at least 9.1, then put the bmm
# functions on the fp16 list. Otherwise, put them on the fp32 list.
_bmms = ['addbmm',
         'baddbmm',
         'bmm']
if utils.get_cuda_version() >= (9, 1, 0):
    FP16_FUNCS.extend(_bmms)
else:
    FP32_FUNCS.extend(_bmms)

# Multi-tensor fns that may need type promotion
CASTS = [
    # Multi-tensor math
...@@ -86,8 +94,9 @@ CASTS = [
    'ne'
]

# Functions that take sequence arguments. We need to inspect the whole
# sequence and cast to the widest type.
SEQUENCE_CASTS = [
    'cat',
    'stack'
]
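For intuition about `SEQUENCE_CASTS`: once the O1 patches are installed, a sequence function such as `torch.cat` inspects every element and promotes to the widest dtype, so mixing FP16 and FP32 inputs behaves roughly like the sketch below (a hedged illustration of the intended behavior, not a test from this commit):

```
import torch

a = torch.randn(2, 2, device="cuda").half()   # float16
b = torch.randn(2, 2, device="cuda")          # float32
# With amp's O1 patching active, the sequence is promoted to the widest type,
# so the concatenation comes back as float32 instead of raising a dtype mismatch.
c = torch.cat([a, b])
```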
import contextlib
import warnings

from .scaler import LossScaler, master_params
from ._amp_state import maybe_print

import numpy as np
...@@ -71,8 +71,7 @@ class OptimWrapper(object):
                'The `closure` argument is unsupported by the amp ' +
                'optimizer wrapper.')
        if any(self._skip_next):
            maybe_print('Gradient overflow, skipping update')
            self._skip_next = [False] * self._num_loss
        else:
            return self._optimizer.step(closure=closure)
......
import torch
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import _amp_state, master_params, maybe_print
from itertools import product

def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
    # Exception handling for 18.04 compatibility
    if check_overflow:
        cpu_sum = float(model_grad.float().sum())
...@@ -19,6 +16,21 @@ def scale_check_overflow_python(model_grad, scale, master_grad, check_overflow=F
    master_grad.mul_(scale)
    return False

def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, check_overflow=False):
    # Exception handling for 18.04 compatibility
    if check_overflow:
        cpu_sum = float(model_grad.float().sum())
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True

    # if master_grad is not model_grad: # copy_ probably internally short-circuits this
    #     master_grad.copy_(model_grad)
    assert stashed_grad.dtype == master_grad.dtype
    converted_model_grad = model_grad.to(master_grad.dtype)
    stashed_grad.add_(scale, converted_model_grad)
    master_grad.data = stashed_grad.data
    return False

class LossScaler(object):
    warned_no_fused_kernel = False
    warned_unscaling_non_fp32_grad = False
...@@ -44,121 +56,139 @@ class LossScaler(object):
                import amp_C
                LossScaler.has_fused_kernel = multi_tensor_applier.available
                LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
                LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
            else:
                if not LossScaler.warned_no_fused_kernel:
                    maybe_print(
                        "Warning:  multi_tensor_applier fused unscale kernel is unavailable, "
                        "possibly because apex was installed without --cuda_ext --cpp_ext. "
                        "Using Python fallback.  Original ImportError was: " +
                        repr(multi_tensor_applier.import_err),
                        True)
                LossScaler.has_fused_kernel = False
                LossScaler.warned_no_fused_kernel = True

    def loss_scale(self):
        return self._loss_scale

    def unscale_python(self, model_grads, master_grads, scale):
        for model, master in zip(model_grads, master_grads):
            if model is not None:
                if not LossScaler.warned_unscaling_non_fp32_grad:
                    if master.dtype != torch.float32:
                        maybe_print(
                            "Attempting to unscale a grad with type {} ".format(master.type()) +
                            "Unscaling non-fp32 grads may indicate an error. "
                            "When using Amp, you don't need to call .half() on your model.")
                        LossScaler.warned_unscaling_non_fp32_grad = True
                self._has_overflow = scale_check_overflow_python(model,
                                                                 master,
                                                                 1./scale,
                                                                 self.dynamic)
                if self._has_overflow and self.dynamic:
                    break

    # unused_scale keeps some of the old API alive for hopefully a short time.
    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False):
        if self._has_overflow:
            return

        scale = self._loss_scale

        if scale == 1.0 and models_are_masters and not self.dynamic:
            return

        if LossScaler.has_fused_kernel:
            # if (not LossScaler.warned_unscaling_non_fp32_grad
            #     and master_grads[0].dtype == torch.float16):
            #     print("Warning:  unscaling grads that are not FP32. "
            #           "Unscaling non-fp32 grads may indicate an error. "
            #           "When using Amp, you don't need to call .half() on your model.")
            # # Setting this to True unconditionally allows the possibility of an escape
            # # if never-before-seen non-fp32 grads are created in some later iteration.
            # LossScaler.warned_unscaling_non_fp32_grad = True
            multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
                                 self._overflow_buf,
                                 [model_grads, master_grads],
                                 1./scale)
        else:
            self.unscale_python(model_grads, master_grads, scale)

        # Defer to update_scale
        # If the fused kernel is available, we only need one D2H memcopy and sync.
        # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
        #     self._has_overflow = self._overflow_buf.item()

    def unscale_with_stashed_python(self,
                                    model_grads,
                                    stashed_master_grads,
                                    master_grads,
                                    scale):
        for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
            if model is None and stashed is None:
                continue
            else:
                if not LossScaler.warned_unscaling_non_fp32_grad:
                    if master.dtype != torch.float32:
                        maybe_print(
                            "Attempting to unscale a grad with type {} ".format(master.type()) +
                            "Unscaling non-fp32 grads may indicate an error. "
                            "When using Amp, you don't need to call .half() on your model.")
                        LossScaler.warned_unscaling_non_fp32_grad = True
                self._has_overflow = axpby_check_overflow_python(model,
                                                                 stashed,
                                                                 master,
                                                                 1./scale,
                                                                 self.dynamic)
                if self._has_overflow and self.dynamic:
                    break

    def unscale_with_stashed(self,
                             model_grads,
                             stashed_master_grads,
                             master_grads):
        if self._has_overflow:
            return

        scale = self._loss_scale

        if LossScaler.has_fused_kernel:
            if (not LossScaler.warned_unscaling_non_fp32_grad
                and master_grads[0].dtype == torch.float16):
                print("Warning:  unscaling grads that are not FP32. "
                      "Unscaling non-fp32 grads may indicate an error. "
                      "When using Amp, you don't need to call .half() on your model.")
                # Setting this to True unconditionally allows the possibility of an escape
                # if never-before-seen non-fp32 grads are created in some later iteration.
                LossScaler.warned_unscaling_non_fp32_grad = True
            multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
                                 self._overflow_buf,
                                 [model_grads, stashed_master_grads, master_grads],
                                 1./scale,
                                 1.0,
                                 0) # check only arg 0, aka the incoming model grads, for infs
        else:
            self.unscale_with_stashed_python(model_grads,
                                             stashed_master_grads,
                                             master_grads,
                                             scale)

        # Defer to update_scale
        # If the fused kernel is available, we only need one D2H memcopy and sync.
        # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
        #     self._has_overflow = self._overflow_buf.item()
def clear_overflow_state(self):
self._has_overflow = False
if self.has_fused_kernel:
self._overflow_buf.zero_()
# Separate so unscale() can be called more that once before updating. # Separate so unscale() can be called more that once before updating.
def update_scale(self): def update_scale(self):
# If the fused kernel is available, we only need one D2H memcopy and sync.
if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
self._has_overflow = self._overflow_buf.item()
if self._has_overflow and self.dynamic: if self._has_overflow and self.dynamic:
should_skip = True should_skip = True
self._loss_scale /= 2. self._loss_scale /= 2.
......
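For orientation, the reworked `unscale` consumes densely packed gradient lists rather than parameter iterables, so the caller is now responsible for pairing model and master grads. A minimal caller-side sketch; the `collect_grads` helper and the variable names are illustrative, not part of this commit:

```python
import torch

def collect_grads(model_params, master_params):
    # Pack aligned grad lists, skipping params that have no gradient and
    # materializing missing master grads, mirroring what FP16_Optimizer
    # does in a hunk further down.
    model_grads, master_grads = [], []
    for model_p, master_p in zip(model_params, master_params):
        if model_p.grad is not None:
            model_grads.append(model_p.grad)
            if master_p.grad is None:
                master_p.grad = torch.empty_like(master_p)
            master_grads.append(master_p.grad)
    return model_grads, master_grads

# Assumed usage with an existing LossScaler instance `scaler`:
# model_grads, master_grads = collect_grads(fp16_params, fp32_master_params)
# scaler.unscale(model_grads, master_grads, scaler.loss_scale())
```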
@@ -5,6 +5,9 @@ import itertools

import torch

def get_cuda_version():
    return tuple(int(x) for x in torch.version.cuda.split('.'))

def is_fp_tensor(x):
    if is_nested(x):
        # Fast-fail version of all(is_fp_tensor)
...
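`get_cuda_version()` simply parses `torch.version.cuda` into an integer tuple, which makes version gates easy to express. A small illustrative check; the `(9, 2)` threshold below is an arbitrary example, not something this commit uses:

```python
import torch

if torch.version.cuda is not None:  # None on CPU-only builds of PyTorch
    cuda_version = tuple(int(x) for x in torch.version.cuda.split('.'))
    needs_workaround = cuda_version < (9, 2)  # example threshold only
```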
@@ -4,6 +4,7 @@ from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from ..amp._amp_state import _amp_state, maybe_print
from ..amp.scaler import LossScaler
from ..multi_tensor_apply import multi_tensor_applier
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
@@ -193,6 +194,8 @@ class FP16_Optimizer(object):
            self.multi_tensor_scale = amp_C.multi_tensor_scale
            self._dummy_overflow_buf = torch.cuda.IntTensor([0]);

    # Having self.maybe_print distinct from _amp_state.maybe_print is another artifact
    # of having to support FP16_Optimizer separately, for the time being.
    def maybe_print(self, msg):
        if self.verbose:
            print(msg)
@@ -401,8 +404,9 @@ class FP16_Optimizer(object):
        # self._update_scale(self.overflow)
        if self.overflow:
            # Using _amp_state.maybe_print instead of self.print here is intentional.
            maybe_print("Gradient overflow. Skipping step, reducing " +
                        "loss scale to {}".format(self.loss_scaler.loss_scale()))
            return

        if closure is not None:
@@ -536,18 +540,37 @@ class FP16_Optimizer(object):
        if len(self.all_fp16_params) > 0:
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp16_params])
            # I'm ONLY writing this as an incremental way to make some tests pass until
            # I can refactor the tests as well.
            # FP16_Optimizer should not be used by anyone.
            model_grads = []
            master_grads = []
            for model_param, master_param in zip(self.all_fp16_params,
                                                 self.all_fp32_from_fp16_params):
                if model_param.grad is not None:
                    model_grads.append(model_param.grad)
                    if master_param.grad is None:
                        master_param.grad = torch.empty_like(master_param)
                    master_grads.append(master_param.grad)
            self.loss_scaler.unscale(
                model_grads,
                master_grads,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp16_params])
        if len(self.all_fp32_from_fp32_params) > 0:
            model_grads = []
            master_grads = []
            for model_param, master_param in zip(self.all_fp32_from_fp32_params,
                                                 self.all_fp32_from_fp32_params):
                if model_param.grad is not None:
                    model_grads.append(model_param.grad)
                    master_grads.append(master_param.grad)
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
            self.loss_scaler.unscale(
                model_grads,
                master_grads,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
...
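For context, these hunks sit inside the legacy `FP16_Optimizer` path. A rough usage sketch of that flow, assuming a CUDA build and the `apex.fp16_utils` import path; the model, sizes, and learning rate are placeholders, and the class is kept only for backward compatibility as the comments above note:

```python
import torch
from apex.fp16_utils import FP16_Optimizer  # import path assumed

model = torch.nn.Linear(16, 16).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# dynamic_loss_scale exercises the LossScaler.update_scale() logic shown earlier.
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

x = torch.randn(4, 16, device="cuda", dtype=torch.half)
loss = model(x).float().sum()
optimizer.backward(loss)   # scales the loss, then unscales into FP32 master grads
optimizer.step()
```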
@@ -3,6 +3,7 @@ import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F
import importlib

class FusedLayerNormAffineFunction(torch.autograd.Function):
@@ -144,6 +145,9 @@ class FusedLayerNorm(torch.nn.Module):
            init.zeros_(self.bias)

    def forward(self, input):
        if not input.is_cuda:
            return F.layer_norm(
                input, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.elementwise_affine:
            return FusedLayerNormAffineFunction(self.normalized_shape,self.eps)(
                input, self.weight, self.bias)
...
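With the new guard, a `FusedLayerNorm` module degrades gracefully to `torch.nn.functional.layer_norm` when it sees CPU tensors. A short sketch, assuming the usual `apex.normalization` import path and that the CUDA extension is built for the GPU branch:

```python
import torch
from apex.normalization import FusedLayerNorm  # import path assumed

ln = FusedLayerNorm(64)

y_cpu = ln(torch.randn(8, 64))  # CPU input: takes the F.layer_norm fallback added above

if torch.cuda.is_available():
    ln = ln.cuda()
    y_gpu = ln(torch.randn(8, 64, device="cuda"))  # CUDA input: fused kernel path
```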
@@ -6,6 +6,7 @@ from collections import OrderedDict
from itertools import chain
import copy
import importlib

from ..multi_tensor_apply import multi_tensor_applier

imported_flatten_impl = False
@@ -226,7 +227,13 @@ class DistributedDataParallel(Module):
        self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
                                    "torch.cuda.FloatTensor" : 1,
                                    "torch.cuda.DoubleTensor" : 2}

        if multi_tensor_applier.available:
            # TODO: I really need to centralize the C++ backed imports
            import amp_C
            self.multi_tensor_scale = amp_C.multi_tensor_scale
            self._overflow_buf = torch.cuda.IntTensor([0])

        self.create_hooks()

        flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
@@ -396,8 +403,15 @@ class DistributedDataParallel(Module):
                             "allreduce buffer. This is almost certainly an error.")
                self.allreduce_buffers[bucket_idx] = allreduced
            else:
                if multi_tensor_applier.available:
                    multi_tensor_applier(
                        self.multi_tensor_scale,
                        self._overflow_buf,
                        [unflatten(allreduced, bucket), bucket],
                        1.0)
                else:
                    for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
                        buf.copy_(synced)

    def allreduce_fallback(self):
...
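When `amp_C` is importable, the bucket copy-out above becomes a single fused launch per chunk instead of a Python loop of `copy_` calls. A standalone sketch of that pattern; the tensor sizes, names, and list length are illustrative:

```python
import torch
import amp_C  # only present when Apex was built with --cuda_ext
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor([0])
src = [torch.randn(1024, device="cuda") for _ in range(4)]
dst = [torch.empty_like(t) for t in src]

# With a scale of 1.0, multi_tensor_scale acts as a fused batched copy,
# replacing four separate dst[i].copy_(src[i]) calls, which is what the
# else-branch above now does per allreduce bucket.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 1.0)
```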
@@ -67,7 +67,10 @@ class SyncBatchNorm(_BatchNorm):
        self.channel_last = channel_last

    def forward(self, input):
        # if input.dim() == 2, we switch to channel_last for efficient memory accessing
        channel_last = self.channel_last if input.dim() != 2 else True

        if not self.training and self.track_running_stats and not channel_last:
            # fall back to pytorch implementation for inference
            return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
        else:
@@ -78,4 +81,4 @@ class SyncBatchNorm(_BatchNorm):
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    exponential_average_factor = self.momentum
            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last)
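The effect of the new `channel_last` switch is that plain 2D `(N, C)` activations, whose last dimension is already the channel dimension, always take the channel-last kernels regardless of how the module was constructed. A usage sketch; the import path is assumed, and the distributed setup is elided:

```python
import torch
from apex.parallel import SyncBatchNorm  # import path assumed

# Assumes the usual setup has already run, e.g.
# torch.distributed.init_process_group(backend="nccl", ...)
bn = SyncBatchNorm(128).cuda()

x = torch.randn(32, 128, device="cuda")  # 2D (N, C): channel_last is forced on internally
y = bn(x)
```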
@@ -22,10 +22,13 @@ class SyncBatchnormFunction(Function):
        if channel_last:
            count = int(input.numel()/input.size(-1))
            mean, var_biased = syncbn.welford_mean_var_c_last(input)
        else:
            count = int(input.numel()/input.size(1))
            mean, var_biased = syncbn.welford_mean_var(input)

        if count == 1:
            raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size()))

        if torch.distributed.is_initialized():
            if not process_group:
                process_group = torch.distributed.group.WORLD
@@ -48,7 +51,7 @@ class SyncBatchnormFunction(Function):
                running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc
        else:
            mean = running_mean.data
            inv_std = 1.0 / torch.sqrt(running_variance.data + eps)

        ctx.save_for_backward(input, weight, mean, inv_std)
        ctx.process_group = process_group
...
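The new `count == 1` guard mirrors `torch.nn.BatchNorm*`: with a single value per channel the batch variance is undefined, so training-mode forward now raises instead of silently producing NaNs. Illustrative only; the import path is assumed and the syncbn CUDA extension is presumed built:

```python
import torch
from apex.parallel import SyncBatchNorm  # import path assumed

bn = SyncBatchNorm(128).cuda().train()
x_bad = torch.randn(1, 128, device="cuda")  # exactly one value per channel
try:
    bn(x_bad)
except ValueError as e:
    print(e)  # "Expected more than 1 value per channel when training, ..."
```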
@@ -72,10 +72,9 @@ class SyncBatchNorm(_BatchNorm):
            return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
        else:
            process_group = self.process_group
            world_size = 1
            if not self.process_group:
                process_group = torch.distributed.group.WORLD
            self.num_batches_tracked += 1
            with torch.no_grad():
                channel_first_input = input.transpose(0, 1).contiguous()
@@ -88,6 +87,7 @@ class SyncBatchNorm(_BatchNorm):
                local_sqr_mean = torch.pow(
                    squashed_input_tensor_view, 2).mean(1)
                if torch.distributed.is_initialized():
                    world_size = torch.distributed.get_world_size(process_group)
                    torch.distributed.all_reduce(
                        local_mean, ReduceOp.SUM, process_group)
                    mean = local_mean / world_size
...
@@ -18,37 +18,26 @@ void multi_tensor_sgd_cuda(
  bool first_run,
  bool wd_after_momentum);

void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
  float b,
  int arg_to_check);

at::Tensor multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
}
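On the Python side these bindings are reached through `multi_tensor_applier`, which supplies the chunk size. A rough sketch of calling the new `multi_tensor_axpby` binding, computing `out = a*x + b*y` per tensor triple and recording infs/NaNs in `noop_flag`; the tensor sizes and the 1/512 scale are illustrative:

```python
import torch
import amp_C  # requires the --cuda_ext build
from apex.multi_tensor_apply import multi_tensor_applier

noop_flag = torch.cuda.IntTensor([0])
x   = [torch.randn(1024, device="cuda")]
y   = [torch.randn(1024, device="cuda")]
out = [torch.empty(1024, device="cuda")]

# out = (1/512)*x + 1.0*y; the final argument selects which input list is
# checked for infs/NaNs (0 -> x), matching unscale_with_stashed() above.
multi_tensor_applier(amp_C.multi_tensor_axpby, noop_flag, [x, y, out], 1./512, 1.0, 0)
```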