"...text-generation-inference.git" did not exist on "842f6658e286f21f0212ae36ad2114125d81d46a"
Commit c2b62b7f authored by JR_ZZU's avatar JR_ZZU 🌴
Browse files

delete origin files

parent 2a4864d5
import contextlib
import warnings
import sys
import torch
from . import utils
from .opt import OptimWrapper
from .scaler import LossScaler
from ._amp_state import _amp_state, master_params, maybe_print
if torch.distributed.is_available():
from ..parallel.LARC import LARC
# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
@contextlib.contextmanager
def scale_loss(loss,
optimizers,
loss_id=0,
model=None,
delay_unscale=False,
delay_overflow_check=False):
"""
On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs
and unscaled, so that ``optimizer.step()`` can be called.
.. note::
If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and
can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)
any FP16 gradients are copied to FP32 master gradients before being unscaled.
``optimizer.step()`` will then apply the unscaled master gradients to the master params.
.. warning::
If Amp is using explicit FP32 master params, only the FP32 master gradients will be
unscaled. The direct ``.grad`` attributes of any FP16
model params will remain scaled after context manager exit.
This subtlety affects gradient clipping. See "Gradient clipping" under
`Advanced Amp Usage`_ for best practices.
Args:
loss(Tensor): Typically a scalar Tensor. The ``scaled_loss`` that the context
manager yields is simply ``loss.float()*loss_scale``, so in principle
``loss`` could have more than one element, as long as you call
``backward()`` on ``scaled_loss`` appropriately within the context manager body.
optimizers: All optimizer(s) for which the current backward pass is creating gradients.
Must be an optimizer or list of optimizers returned from an earlier call
to ``amp.initialize``. For example use with multiple optimizers, see
"Multiple models/optimizers/losses" under `Advanced Amp Usage`_.
loss_id(int, optional, default=0): When used in conjunction with the ``num_losses`` argument
to ``amp.initialize``, enables Amp to use a different loss scale per loss. ``loss_id``
must be an integer between 0 and ``num_losses`` that tells Amp which loss is
being used for the current backward pass. See "Multiple models/optimizers/losses"
under `Advanced Amp Usage`_ for examples. If ``loss_id`` is left unspecified, Amp
will use the default global loss scaler for this backward pass.
model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
optimizations.
delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary, and
the default value of ``False`` is strongly recommended.
If ``True``, Amp will not unscale the gradients or perform model->master
gradient copies on context manager exit.
``delay_unscale=True`` is a minor ninja performance optimization and can result
in weird gotchas (especially with multiple models/optimizers/losses),
so only use it if you know what you're doing.
"Gradient accumulation across iterations" under `Advanced Amp Usage`_
illustrates a situation where this CAN (but does not need to) be used.
.. warning::
If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be
called yet after context manager exit, and must wait for another, later backward context
manager invocation with ``delay_unscale`` left as ``False``.
.. _`Advanced Amp Usage`:
https://nvidia.github.io/apex/advanced.html
"""
if not hasattr(_amp_state, "opt_properties"):
raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized. "
"model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called "
"before `with amp.scale_loss`.")
if not _amp_state.opt_properties.enabled:
yield loss
return
if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
optimizers = [optimizers]
loss_scaler = _amp_state.loss_scalers[loss_id]
loss_scale = loss_scaler.loss_scale()
if ((not _amp_state.opt_properties.master_weights)
and (not loss_scaler.dynamic)
and loss_scale == 1.0):
yield loss.float()
# Needing to drop the cache here as well is an ugly gotcha.
# But for now I think it's necessary to short-circuit.
# Probably ok to skip this if not delay_unscale
if _amp_state.opt_properties.patch_torch_functions:
_amp_state.handle._clear_cache()
return
if not delay_unscale:
if isinstance(optimizers, list):
for optimizer in optimizers:
if not optimizer._amp_stash.params_have_scaled_gradients:
optimizer._prepare_amp_backward()
yield (loss.float())*loss_scale
if delay_unscale:
for optimizer in optimizers:
optimizer._amp_stash.params_have_scaled_gradients = True
else:
# FusedSGD may take care of unscaling as part of its step() method.
# if not isinstance(optimizers, FP16_Optimizer_for_fused):
loss_scaler.clear_overflow_state()
for optimizer in optimizers:
optimizer._post_amp_backward(loss_scaler)
optimizer._amp_stash.params_have_scaled_gradients = False
# For future fused optimizers that enable sync-free dynamic loss scaling,
# should_skip will always be False.
should_skip = False if delay_overflow_check else loss_scaler.update_scale()
if should_skip:
for optimizer in optimizers:
if not optimizer._amp_stash.already_patched:
# Close on loss_scaler and loss_id as well, to be safe. Probably not
# necessary because amp.scale_loss is already creating a temporary scope.
def patch_step(opt, loss_scaler, loss_id):
opt_step = opt.step
def skip_step(closure=None):
if closure is not None:
raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
maybe_print(("Gradient overflow. Skipping step, loss scaler " +
"{} reducing loss scale to {}").format(loss_id,
loss_scaler.loss_scale()))
# TODO: I don't like the special casing for different optimizer implementations.
# Maybe skip should delegate to a method owned by the optimizers themselves.
if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
# Clear the master grads that wouldn't be zeroed by model.zero_grad()
for param in opt._amp_stash.all_fp32_from_fp16_params:
param.grad = None
if hasattr(opt, "most_recent_scale"):
opt.most_recent_scale = 1.0
opt.scale_set_by_backward = False
opt.step = opt_step
opt._amp_stash.already_patched = False
return skip_step
optimizer.step = patch_step(optimizer, loss_scaler, loss_id)
optimizer._amp_stash.already_patched = True
# Probably ok to skip this if not delay_unscale
if _amp_state.opt_properties.patch_torch_functions:
_amp_state.handle._clear_cache()
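
# ---------------------------------------------------------------------------
# Editorial usage sketch (not part of the original file): the intended
# training-loop pattern for scale_loss, following the docstring above.
# `build_model_and_optimizer`, `data_loader`, and `loss_fn` are hypothetical
# user-supplied callables.
def _example_scale_loss_usage(build_model_and_optimizer, data_loader, loss_fn):
    from apex import amp
    model, optimizer = build_model_and_optimizer()
    # amp.initialize must run before the first call to amp.scale_loss.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        # Backward on the scaled loss; gradients are unscaled on exit.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()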
# Free function version of AmpHandle.disable_casts, another step on the
# path to removing the concept of "AmpHandle"
@contextlib.contextmanager
def disable_casts():
_amp_state.handle._is_active = False
yield
_amp_state.handle._is_active = True
class AmpHandle(object):
def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
self._enable_caching = enable_caching
self._verbose = verbose
self._cache = dict()
self._default_scaler = LossScaler(loss_scale)
self._is_active = True
self._all_wrappers = []
def is_active(self):
return self._is_active
@contextlib.contextmanager
def _disable_casts(self):
self._is_active = False
yield
self._is_active = True
def wrap_optimizer(self, optimizer, num_loss=1):
self._default_scaler = None
return OptimWrapper(optimizer, self, num_loss)
@contextlib.contextmanager
def scale_loss(self, loss, optimizer):
raise RuntimeError("The old Amp API is no longer supported. Please move to the new API, "
"documented here: https://nvidia.github.io/apex/amp.html. Transition guide: "
"https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users")
if not self.is_active():
yield loss
return
if self._default_scaler is None:
raise RuntimeError(
'After calling `handle.wrap_optimizer()`, you must explicitly ' +
'use `optimizer.scale_loss(loss)`.')
# TODO: this code block is duplicated here and `opt.py`. Unify.
loss_scale = self._default_scaler.loss_scale()
yield loss * loss_scale
self._default_scaler.clear_overflow_state()
self._default_scaler.unscale(
master_params(optimizer),
master_params(optimizer),
loss_scale)
should_skip = self._default_scaler.update_scale()
if should_skip:
optimizer_step = optimizer.step
def skip_step():
maybe_print('Gradient overflow, skipping update')
optimizer.step = optimizer_step
optimizer.step = skip_step
self._clear_cache()
def _clear_cache(self):
self._cache.clear()
# Experimental support for saving / restoring uncasted versions of functions
def _save_func(self, mod, fn, func):
self._all_wrappers.append((mod, fn, func))
def _deactivate(self):
for mod, fn, func in self._all_wrappers:
utils.set_func(mod, fn, func)
self._all_wrappers = []
@property
def has_cache(self):
return self._enable_caching
@property
def cache(self):
return self._cache
def remove_cache(self, param):
if self.has_cache and param in self.cache:
del self.cache[param]
@property
def verbose(self):
return self._verbose
class NoOpHandle(object):
def is_active(self):
return False
@contextlib.contextmanager
def _disable_casts(self):
yield
def wrap_optimizer(self, optimizer, num_loss=1):
return OptimWrapper(optimizer, self, num_loss)
@contextlib.contextmanager
def scale_loss(self, loss, optimizer):
yield loss
@property
def has_cache(self):
return False
@property
def verbose(self):
return False
def _clear_cache(self):
pass
def _deactivate(self):
pass
# TODO: think about the following two. They do weird things.
# - torch.nn.utils.clip_grad (but it should always be fp32 anyway)
# - torch.nn.utils.weight_norm
# Notes:
# F.instance_norm uses batch_norm internally. Which correctly handles
# fp16 in/out with fp32 weights. So we shouldn't do anything for
# either of these.
# F.normalize calls `input.norm()` internally, so it's redundant, but
# kept here in case impl. changes.
# F.cosine_similarity is same: calls `x.norm()` internally.
import torch.nn.functional
MODULE = torch.nn.functional
FP16_FUNCS = [
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc', # Undocumented / maybe new?
'linear',
]
BFLOAT16_FUNCS = [
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc', # Undocumented / maybe new?
'linear',
]
FP32_FUNCS = [
# Interpolation/Upsampling TODO: Remove for 1.2
'interpolate',
'grid_sample',
# Pointwise
'softplus',
'softmin',
'log_softmax',
'softmax',
'gelu',
# Normalization
'layer_norm',
'group_norm',
'local_response_norm',
'normalize',
'cosine_similarity',
# Loss functions
# TODO: which of these can be fp16?
'poisson_nll_loss',
'cosine_embedding_loss',
'cross_entropy',
'hinge_embedding_loss',
'kl_div',
'l1_loss',
'mse_loss',
'margin_ranking_loss',
'multilabel_margin_loss',
'multilabel_soft_margin_loss',
'multi_margin_loss',
'nll_loss',
'binary_cross_entropy_with_logits',
'smooth_l1_loss',
'soft_margin_loss',
'triplet_margin_loss',
'ctc_loss'
]
BANNED_FUNCS = [
('binary_cross_entropy',
("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
"It requires that the output of the previous function be already a FloatTensor. \n\n"
"Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
" torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
"that is compatible with amp.\nAnother option is to add\n"
" amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
"If you _really_ know what you are doing, you can disable this warning by passing "
"allow_banned=True to `amp.init()`."))
]
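
# Editorial sketch (not part of the original file): the workaround the banned
# message above recommends. `logits` and `targets` are hypothetical tensors;
# the Sigmoid + BCELoss pair is replaced by the fused op, which appears in
# FP32_FUNCS above and is therefore run in fp32 under amp.
def _example_bce_workaround(logits, targets):
    import torch
    return torch.nn.functional.binary_cross_entropy_with_logits(logits, targets)
# Alternatively, as the message notes, one can keep Sigmoid + BCELoss by
# registering torch.sigmoid as a float function before amp.init():
#   amp.register_float_function(torch, 'sigmoid')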
from .. import compat
from . import torch_overrides
import importlib
import torch
# if compat.variable_is_tensor() and not compat.tensor_is_variable():
MODULE = torch.Tensor
# else:
# MODULE = torch.autograd.Variable
FP16_FUNCS = compat.filter_attrs(MODULE, [
'__matmul__',
])
BFLOAT16_FUNCS = [
'__matmul__',
]
FP32_FUNCS = compat.filter_attrs(MODULE, [
'__ipow__',
'__pow__',
'__rpow__',
# Cast to fp32 before transfer to CPU
'cpu',
])
CASTS = compat.filter_attrs(MODULE, [
'__add__',
'__div__',
'__eq__',
'__ge__',
'__gt__',
'__iadd__',
'__idiv__',
'__imul__',
'__isub__',
'__itruediv__',
'__le__',
'__lt__',
'__mul__',
'__ne__',
'__radd__',
'__rdiv__',
'__rmul__',
'__rsub__',
'__rtruediv__',
'__sub__',
'__truediv__',
])
# None of these, but here to make code cleaner.
SEQUENCE_CASTS = []
# We need to grab all the methods from torch_overrides and add them to
# the Tensor lists as well, as almost all methods are duplicated
# between `torch` and `torch.Tensor` (and check with `hasattr`,
# because a few random ones aren't defined on Tensor)
_self_mod = importlib.import_module(__name__)
for attrname in ['FP16_FUNCS', 'BFLOAT16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
lst = getattr(_self_mod, attrname)
for fn in getattr(torch_overrides, attrname):
if hasattr(MODULE, fn):
lst.append(fn)
import torch
from .. import utils
MODULE = torch
FP16_FUNCS = [
# Low level functions wrapped by torch.nn layers.
# The wrapper layers contain the weights which are then passed in as a parameter
# to these functions.
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc',
'prelu',
# BLAS
'addmm',
'addmv',
'addr',
'matmul',
'mm',
'mv',
]
BFLOAT16_FUNCS = [
# Low level functions wrapped by torch.nn layers.
# The wrapper layers contain the weights which are then passed in as a parameter
# to these functions.
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc',
# BLAS
'addmm',
'addmv',
'addr',
'matmul',
'mm',
'mv',
]
FP32_FUNCS = [
# Pointwise
'acos',
'asin',
'cosh',
'erfinv',
'exp',
'expm1',
'log',
'log10',
'log2',
'reciprocal',
'rsqrt',
'sinh',
'tan',
# Other math
'pow',
# Reduction
'cumprod',
'cumsum',
'dist',
# 'mean',
'norm',
'prod',
'std',
'sum',
'var',
# Misc
'renorm'
]
version_strings = torch.__version__.split('.')
version_major = int(version_strings[0])
version_minor = int(version_strings[1])
# Before torch 1.1, mean must be blacklisted.
# (Compare as an integer tuple so versions like "1.10" parse correctly.)
if (version_major, version_minor) < (1, 1):
FP32_FUNCS.append('mean')
# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
# check the CUDA version -- if at least 9.1, then put the bmm
# functions on the fp16 list. Otherwise, put them on the fp32 list.
_bmms = ['addbmm',
'baddbmm',
'bmm']
if utils.is_cuda_enabled():
# workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802
if utils.get_cuda_version() >= (9, 1, 0):
FP16_FUNCS.extend(_bmms)
else:
FP32_FUNCS.extend(_bmms)
# Multi-tensor fns that may need type promotion
CASTS = [
# Multi-tensor math
'addcdiv',
'addcmul',
'atan2',
'cross',
'bilinear',
'dot',
# Element-wise _or_ tensor-wise math
'add',
'div',
'mul',
# Comparison
'eq',
'equal',
'ge',
'gt',
'le',
'lt',
'ne'
]
# Functions that take sequence arguments. We need to inspect the whole
# sequence and cast to the widest type.
SEQUENCE_CASTS = [
'cat',
'stack'
]
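
# Editorial sketch (not part of the original file): how these lists are
# typically consumed. Something along these lines runs during amp setup for
# each whitelist entry; the real wiring lives elsewhere in apex.amp, so this
# only illustrates the pairing with wrap.cached_cast, not the actual
# initialization code.
def _example_apply_fp16_whitelist(handle, verbose=False):
    from .. import utils, wrap
    for fn in FP16_FUNCS:
        # Cast floating-point inputs of each whitelisted torch.* function to
        # fp16 (with weight caching) before the original kernel runs.
        wrap.cached_cast(MODULE, fn, utils.maybe_half, handle,
                         try_caching=True, verbose=verbose)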
import contextlib
import warnings
from .scaler import LossScaler, master_params
from ._amp_state import maybe_print
import numpy as np
class OptimWrapper(object):
def __init__(self, optimizer, amp_handle, num_loss):
self._optimizer = optimizer
self._amp_handle = amp_handle
self._num_loss = num_loss
self._loss_idx = 0
self._skip_next = [False] * num_loss
self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)]
@contextlib.contextmanager
def scale_loss(self, loss):
if not self._amp_handle.is_active():
yield loss
return
# When there are multiple losses per-optimizer, we need
# to save out current grad accumulation, since we won't be
# able to unscale this particular loss once the grads are
# all mixed together.
cached_grads = []
if self._loss_idx > 0:
for p in master_params(self._optimizer):
if p.grad is not None:
cached_grads.append(p.grad.data.detach().clone())
else:
cached_grads.append(None)
self._optimizer.zero_grad()
loss_scale = self._cur_loss_scaler().loss_scale()
yield loss * loss_scale
self._cur_loss_scaler().clear_overflow_state()
self._cur_loss_scaler().unscale(
master_params(self._optimizer),
master_params(self._optimizer),
loss_scale)
self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()
self._loss_idx += 1
if len(cached_grads) > 0:
for p, cached_grad in zip(master_params(self._optimizer),
cached_grads):
if cached_grad is not None:
p.grad.data.add_(cached_grad)
cached_grads = []
def _cur_loss_scaler(self):
assert 0 <= self._loss_idx < self._num_loss
return self._loss_scaler[self._loss_idx]
def step(self, closure=None):
if not self._amp_handle.is_active():
return self._optimizer.step(closure=closure)
self._loss_idx = 0
for group in self._optimizer.param_groups:
for p in group['params']:
self._amp_handle.remove_cache(p)
if closure is not None:
raise NotImplementedError(
'The `closure` argument is unsupported by the amp ' +
'optimizer wrapper.')
if any(self._skip_next):
maybe_print('Gradient overflow, skipping update')
self._skip_next = [False] * self._num_loss
else:
return self._optimizer.step(closure=closure)
# Forward any attribute lookups
def __getattr__(self, attr):
return getattr(self._optimizer, attr)
# Forward all torch.optim.Optimizer methods
def __getstate__(self):
return self._optimizer.__getstate__()
def __setstate__(self, state):
return self._optimizer.__setstate__(state)
def __repr__(self):
return self._optimizer.__repr__()
def state_dict(self):
return self._optimizer.state_dict()
def load_state_dict(self, state_dict):
return self._optimizer.load_state_dict(state_dict)
def zero_grad(self):
return self._optimizer.zero_grad()
def add_param_group(self, param_group):
return self._optimizer.add_param_group(param_group)
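
# Editorial sketch (not part of the original file): the legacy multi-loss
# pattern this wrapper implements (handle.wrap_optimizer + optimizer.scale_loss),
# shown only to clarify how _loss_idx cycles through the per-loss scalers.
# New code should use amp.initialize / amp.scale_loss instead. The losses are
# assumed to come from independent forward passes.
def _example_optim_wrapper_usage(handle, optimizer, losses):
    wrapped = handle.wrap_optimizer(optimizer, num_loss=len(losses))
    for loss in losses:
        # Each scale_loss call advances _loss_idx and uses its own LossScaler.
        with wrapped.scale_loss(loss) as scaled_loss:
            scaled_loss.backward()
    # step() resets _loss_idx and skips the update if any loss overflowed.
    wrapped.step()
    wrapped.zero_grad()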
from . import utils, wrap
import torch
_VF = torch._C._VariableFunctions
RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']
def _gen_VF_wrapper(name):
def wrapper(*args, **kwargs):
return getattr(_VF, name)(*args, **kwargs)
return wrapper
# Some python magic to generate an object that has the rnn cell functions
# defined on it, all of which call into corresponding _VF version.
# Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF"
# imported at module scope within torch.nn.modules.rnn). This should
# not affect third-party importers of _VF.py.
class VariableFunctionsShim(object):
def __init__(self):
for name in RNN_NAMES:
for suffix in ['', '_cell']:
fn_name = name + suffix
setattr(self, fn_name, _gen_VF_wrapper(fn_name))
def has_old_rnns():
try:
torch.nn.backends.thnn.backend.LSTMCell
return True
except:
return False
def whitelist_rnn_cells(cast_fn, handle, verbose):
# Different module + function names in old/new RNN cases
if has_old_rnns():
fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
mod = torch.nn.backends.thnn.backend
else:
fn_names = [x + '_cell' for x in RNN_NAMES]
mod = torch.nn.modules.rnn._VF
assert isinstance(mod, VariableFunctionsShim)
# Insert casts on cell functions
for fn in fn_names:
wrap.cached_cast(mod, fn, cast_fn, handle,
try_caching=True, verbose=verbose)
if has_old_rnns():
# Special handling of `backward` for fused gru / lstm:
# The `backward` method calls Tensor.sum() (blacklist) internally,
# and then the resulting grad_input has the wrong type.
# TODO: where else is this a problem?
for rnn_type in ['GRUFused', 'LSTMFused']:
mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
wrap.disable_casts(mod, 'backward', handle)
import torch
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import _amp_state, master_params, maybe_print
from itertools import product
def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
# Exception handling for 18.04 compatibility
if check_overflow:
if model_grad.is_sparse:
cpu_sum = float(model_grad.float()._values().sum())
else:
cpu_sum = float(model_grad.float().sum())
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
if master_grad is not model_grad: # copy_ probably internally short-circuits this
if model_grad.is_sparse:
master_grad.copy_(model_grad.to_dense())
else:
master_grad.copy_(model_grad)
if scale != 1.0:
master_grad.mul_(scale)
return False
def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
# Exception handling for 18.04 compatibility
if check_overflow:
if model_grad.is_sparse:
cpu_sum = float(model_grad.float()._values().sum())
else:
cpu_sum = float(model_grad.float().sum())
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
# if master_grad is not model_grad: # copy_ probably internally short-circuits this
# master_grad.copy_(model_grad)
assert stashed_grad.dtype == master_grad.dtype
converted_model_grad = model_grad.data.to(master_grad.dtype)
master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
return False
class LossScaler(object):
warned_no_fused_kernel = False
warned_unscaling_non_fp32_grad = False
has_fused_kernel = False
def __init__(self,
loss_scale,
init_scale=2.**16,
scale_factor=2.,
scale_window=2000,
min_loss_scale=None,
max_loss_scale=2.**24):
if loss_scale == "dynamic":
self.dynamic = True
self._loss_scale = min(max_loss_scale, init_scale)
else:
self.dynamic = False
self._loss_scale = loss_scale
self._max_loss_scale = max_loss_scale
self._min_loss_scale = min_loss_scale
self._scale_seq_len = scale_window
self._unskipped = 0
self._has_overflow = False
self._overflow_buf = torch.cuda.IntTensor([0])
if multi_tensor_applier.available:
import amp_C
LossScaler.has_fused_kernel = multi_tensor_applier.available
LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
else:
if not LossScaler.warned_no_fused_kernel:
maybe_print(
"Warning: multi_tensor_applier fused unscale kernel is unavailable, "
"possibly because apex was installed without --cuda_ext --cpp_ext. "
"Using Python fallback. Original ImportError was: " +
repr(multi_tensor_applier.import_err),
True)
LossScaler.has_fused_kernel = False
LossScaler.warned_no_fused_kernel = True
def loss_scale(self):
return self._loss_scale
def unscale_python(self, model_grads, master_grads, scale):
for model, master in zip(model_grads, master_grads):
if model is not None:
if not LossScaler.warned_unscaling_non_fp32_grad:
if master.dtype != torch.float32:
maybe_print(
"Attempting to unscale a grad with type {} ".format(master.type()) +
"Unscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
LossScaler.warned_unscaling_non_fp32_grad = True
self._has_overflow = scale_check_overflow_python(model,
master,
1./scale,
self.dynamic)
if self._has_overflow and self.dynamic:
break
# unused_scale keeps some of the old API alive for hopefully a short time.
def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
if self._has_overflow:
return
scale = self._loss_scale
if scale_override is not None:
scale = scale_override
if scale == 1.0 and models_are_masters and not self.dynamic:
return
if LossScaler.has_fused_kernel:
# if (not LossScaler.warned_unscaling_non_fp32_grad
# and master_grads[0].dtype == torch.float16):
# print("Warning: unscaling grads that are not FP32. "
# "Unscaling non-fp32 grads may indicate an error. "
# "When using Amp, you don't need to call .half() on your model.")
# # Setting this to True unconditionally allows the possibility of an escape
# # if never-before-seen non-fp32 grads are created in some later iteration.
# LossScaler.warned_unscaling_non_fp32_grad = True
multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
self._overflow_buf,
[model_grads, master_grads],
1./scale)
else:
self.unscale_python(model_grads, master_grads, scale)
# Defer to update_scale
# If the fused kernel is available, we only need one D2H memcopy and sync.
# if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
# self._has_overflow = self._overflow_buf.item()
def unscale_with_stashed_python(self,
model_grads,
stashed_master_grads,
master_grads,
a,
b):
for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
if model is None and stashed is None:
continue
else:
if not LossScaler.warned_unscaling_non_fp32_grad:
if master.dtype != torch.float32:
maybe_print(
"Attempting to unscale a grad with type {} ".format(master.type()) +
"Unscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
LossScaler.warned_unscaling_non_fp32_grad = True
self._has_overflow = axpby_check_overflow_python(model,
stashed,
master,
a,
b,
self.dynamic)
if self._has_overflow and self.dynamic:
break
def unscale_with_stashed(self,
model_grads,
stashed_master_grads,
master_grads,
scale_override=None):
if self._has_overflow:
return
grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
if scale_override is not None:
grads_have_scale, stashed_have_scale, out_scale = scale_override
if LossScaler.has_fused_kernel:
if (not LossScaler.warned_unscaling_non_fp32_grad
and master_grads[0].dtype == torch.float16):
print("Warning: unscaling grads that are not FP32. "
"Unscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
# Setting this to True unconditionally allows the possibility of an escape
# if never-before-seen non-fp32 grads are created in some later iteration.
LossScaler.warned_unscaling_non_fp32_grad = True
multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
self._overflow_buf,
[model_grads, stashed_master_grads, master_grads],
out_scale/grads_have_scale, # 1./scale,
out_scale/stashed_have_scale, # 1.0,
0) # check only arg 0, aka the incoming model grads, for infs
else:
self.unscale_with_stashed_python(model_grads,
stashed_master_grads,
master_grads,
out_scale/grads_have_scale,
out_scale/stashed_have_scale)
# Defer to update_scale
# If the fused kernel is available, we only need one D2H memcopy and sync.
# if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
# self._has_overflow = self._overflow_buf.item()
def clear_overflow_state(self):
self._has_overflow = False
if self.has_fused_kernel:
self._overflow_buf.zero_()
# Separate so unscale() can be called more than once before updating.
def update_scale(self):
# If the fused kernel is available, we only need one D2H memcopy and sync.
if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
self._has_overflow = self._overflow_buf.item()
if self._has_overflow and self.dynamic:
should_skip = True
if self._min_loss_scale:
self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
else:
self._loss_scale = self._loss_scale/2.
self._unskipped = 0
else:
should_skip = False
self._unskipped += 1
if self._unskipped == self._scale_seq_len and self.dynamic:
self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
self._unskipped = 0
return should_skip
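
# Editorial illustration (not part of the original file): the dynamic schedule
# implemented by update_scale() above, mirrored in plain Python so it can be
# traced without CUDA or real gradients. `overflow_steps` stands in for the
# iterations on which unscale() would detect inf/nan grads; the min_loss_scale
# floor is omitted for brevity.
def _example_dynamic_scale_schedule(overflow_steps, num_steps=10,
                                    init_scale=2.**16, scale_window=2000,
                                    max_loss_scale=2.**24):
    loss_scale, unskipped, history = init_scale, 0, []
    for step in range(num_steps):
        if step in overflow_steps:
            # Overflow: skip the optimizer step and halve the scale.
            loss_scale = loss_scale / 2.
            unskipped = 0
            should_skip = True
        else:
            should_skip = False
            unskipped += 1
        if unskipped == scale_window:
            # scale_window clean steps in a row: try a larger scale.
            loss_scale = min(max_loss_scale, loss_scale * 2.)
            unskipped = 0
        history.append((step, should_skip, loss_scale))
    return history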
from . import compat
import functools
import itertools
import torch
def is_cuda_enabled():
return torch.version.cuda is not None
def get_cuda_version():
return tuple(int(x) for x in torch.version.cuda.split('.'))
def is_fp_tensor(x):
if is_nested(x):
# Fast-fail version of all(is_fp_tensor)
for y in x:
if not is_fp_tensor(y):
return False
return True
return compat.is_tensor_like(x) and compat.is_floating_point(x)
def is_nested(x):
return isinstance(x, tuple) or isinstance(x, list)
def should_cache(x):
if is_nested(x):
# Fast-fail version of all(should_cache)
for y in x:
if not should_cache(y):
return False
return True
return isinstance(x, torch.nn.parameter.Parameter) and \
type_string(x) == 'FloatTensor'
def collect_fp_tensor_types(args, kwargs):
def collect_types(x, types):
if is_nested(x):
for y in x:
collect_types(y, types)
else:
types.add(type_string(x))
all_args = itertools.chain(args, kwargs.values())
types = set()
for x in all_args:
if is_fp_tensor(x):
collect_types(x, types)
return types
def type_string(x):
return x.type().split('.')[-1]
def maybe_half(x, name='', verbose=False):
if is_nested(x):
return type(x)([maybe_half(y) for y in x])
if not x.is_cuda or type_string(x) == 'HalfTensor':
return x
else:
if verbose:
print('Float->Half ({})'.format(name))
return x.half()
def maybe_bfloat16(x, name='', verbose=False):
if is_nested(x):
return type(x)([maybe_bfloat16(y) for y in x])
if not x.is_cuda or type_string(x) == 'BFloat16Tensor':
return x
else:
if verbose:
print('Float->BFloat16 ({})'.format(name))
return x.bfloat16()
def maybe_float(x, name='', verbose=False):
if is_nested(x):
return type(x)([maybe_float(y) for y in x])
if not x.is_cuda or type_string(x) == 'FloatTensor':
return x
else:
if verbose:
print('Half->Float ({})'.format(name))
return x.float()
# NB: returns casted `args`, mutates `kwargs` in-place
def casted_args(cast_fn, args, kwargs):
new_args = []
for x in args:
if is_fp_tensor(x):
new_args.append(cast_fn(x))
else:
new_args.append(x)
for k in kwargs:
val = kwargs[k]
if is_fp_tensor(val):
kwargs[k] = cast_fn(val)
return new_args
def cached_cast(cast_fn, x, cache):
if is_nested(x):
return type(x)([cached_cast(cast_fn, y, cache) for y in x])
if x in cache:
cached_x = cache[x]
next_functions_available = False
if x.requires_grad and cached_x.requires_grad:
if len(cached_x.grad_fn.next_functions) > 1:
next_functions_available = True
# Make sure x is actually cached_x's autograd parent.
if next_functions_available and cached_x.grad_fn.next_functions[1][0].variable is not x:
raise RuntimeError("x and cache[x] both require grad, but x is not "
"cache[x]'s parent. This is likely an error.")
# During eval, it's possible to end up caching casted weights with
# requires_grad=False. On the next training iter, if cached_x is found
# and reused from the cache, it will not actually have x as its parent.
# Therefore, we choose to invalidate the cache (and force refreshing the cast)
# if x.requires_grad and cached_x.requires_grad do not match.
#
# During eval (i.e. running under with torch.no_grad()) the invalidation
# check would cause the cached value to be dropped every time, because
# cached_x would always be created with requires_grad=False, while x would
# still have requires_grad=True. This would render the cache effectively
# useless during eval. Therefore, if we are running under the no_grad()
# context manager (torch.is_grad_enabled=False) we elide the invalidation
# check, and use the cached value even though its requires_grad flag doesn't
# match. During eval, we don't care that there's no autograd-graph
# connection between x and cached_x.
if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
del cache[x]
elif x.requires_grad and cached_x.requires_grad and not next_functions_available:
del cache[x]
else:
return cached_x
casted_x = cast_fn(x)
cache[x] = casted_x
return casted_x
def verbosify(cast_fn, fn_name, verbose):
if verbose:
return functools.partial(cast_fn, name=fn_name, verbose=verbose)
else:
return cast_fn
def as_inplace(fns):
for x in fns:
yield x + '_'
def has_func(mod, fn):
if isinstance(mod, dict):
return fn in mod
else:
return hasattr(mod, fn)
def get_func(mod, fn):
if isinstance(mod, dict):
return mod[fn]
else:
return getattr(mod, fn)
def set_func(mod, fn, new_fn):
if isinstance(mod, dict):
mod[fn] = new_fn
else:
setattr(mod, fn, new_fn)
def set_func_save(handle, mod, fn, new_fn):
cur_fn = get_func(mod, fn)
handle._save_func(mod, fn, cur_fn)
set_func(mod, fn, new_fn)
# A couple problems get solved here:
# - The flat_weight buffer is disconnected from autograd graph,
# so the fp16 weights need to be derived from the input weights
# to this forward call, not the flat buffer.
# - The ordering of weights in the flat buffer is...idiosyncratic.
# First problem is solved with combination of set_ (to set up
# correct storage) and copy_ (so the fp16 weight derives from the
# fp32 one in autograd).
# Second is solved by doing ptr arithmetic on the fp32 weights
# to derive the correct offset.
#
# TODO: maybe this should actually use
# `torch._cudnn_rnn_flatten_weight`? But then I need to call
# on first iter and cache the right offsets. Ugh.
def synthesize_flattened_rnn_weights(fp32_weights,
fp16_flat_tensor,
rnn_fn='',
verbose=False):
fp16_weights = []
fp32_base_ptr = fp32_weights[0][0].data_ptr()
for layer_weights in fp32_weights:
fp16_layer_weights = []
for w_fp32 in layer_weights:
w_fp16 = w_fp32.new().half()
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
w_fp16.set_(fp16_flat_tensor.storage(),
offset,
w_fp32.shape)
w_fp16.copy_(w_fp32)
if verbose:
print('Float->Half ({})'.format(rnn_fn))
fp16_layer_weights.append(w_fp16)
fp16_weights.append(fp16_layer_weights)
return fp16_weights
def _str_from_dtype(dtype=torch.float16):
type_to_str = {torch.float16 : 'Half',
torch.bfloat16 : 'BFloat16'}
return type_to_str[dtype]
# Roughly same as above, just the `fp32_weights` aren't nested.
# Code kept separate for readability.
def new_synthesize_flattened_rnn_weights(fp32_weights,
fp16_flat_tensor,
rnn_fn='',
dtype=torch.float16,
verbose=False):
fp16_weights = []
fp32_base_ptr = fp32_weights[0].data_ptr()
for w_fp32 in fp32_weights:
w_fp16 = w_fp32.new().to(dtype=dtype)
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
w_fp16.set_(fp16_flat_tensor.storage(),
offset,
w_fp32.shape)
w_fp16.copy_(w_fp32)
if verbose:
print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn))
fp16_weights.append(w_fp16)
return fp16_weights
from . import compat
from . import utils
from ._amp_state import _amp_state
from . import rnn_compat
import functools
import torch
def make_cast_wrapper(orig_fn, cast_fn, handle,
try_caching=False):
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
if not handle.is_active():
return orig_fn(*args, **kwargs)
if try_caching and handle.has_cache:
args = list(args)
for i in range(len(args)):
if utils.should_cache(args[i]):
args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)
for k in kwargs:
if utils.should_cache(kwargs[k]):
kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)
new_args = utils.casted_args(cast_fn,
args,
kwargs)
return orig_fn(*new_args, **kwargs)
return wrapper
def cached_cast(mod, fn, cast_fn, handle,
try_caching=False, verbose=False):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
utils.set_func_save(handle, mod, fn, wrapper)
# `handle` arg is unused, but is kept so the API matches `make_cast_wrapper`.
# Annoyingly, make_promote_wrapper still uses the global handle. Once everyone
# is on the new API and I am free to get rid of handle, I can clean this up.
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
if not _amp_state.handle.is_active():
return orig_fn(*args, **kwargs)
types = utils.collect_fp_tensor_types(args, kwargs)
if len(types) <= 1:
return orig_fn(*args, **kwargs)
elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor'])
or types == set(['BFloat16Tensor', 'FloatTensor'])):
new_args = utils.casted_args(cast_fn,
args,
kwargs)
return orig_fn(*new_args, **kwargs)
else:
raise NotImplementedError('Do not know how to handle ' +
'these types to promote: {}'
.format(types))
return wrapper
def promote(mod, fn, handle, verbose=False):
orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
wrapper = make_promote_wrapper(orig_fn, maybe_float)
utils.set_func_save(handle, mod, fn, wrapper)
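
# Editorial illustration (not part of the original file): the promotion rule
# make_promote_wrapper applies. When fp16 and fp32 CUDA tensors are mixed,
# every floating-point argument is widened to fp32 before the original op is
# called. `a` and `b` are hypothetical CUDA tensors.
def _example_promote(a, b):
    types = utils.collect_fp_tensor_types((a, b), {})
    if types == set(['HalfTensor', 'FloatTensor']):
        a, b = utils.casted_args(utils.maybe_float, (a, b), {})
    return a + b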
def sequence_promote(mod, fn, handle, verbose=False):
orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
@functools.wraps(orig_fn)
def wrapper(seq, *args, **kwargs):
if not _amp_state.handle.is_active():
return orig_fn(seq, *args, **kwargs)
types = set([utils.type_string(x) for x in seq])
if len(types) <= 1:
return orig_fn(seq, *args, **kwargs)
elif (types == set(['HalfTensor', 'FloatTensor']) or
types == set(['BFloat16Tensor', 'FloatTensor'])):
cast_seq = utils.casted_args(maybe_float,
seq, {})
return orig_fn(cast_seq, *args, **kwargs)
else:
# TODO: other mixed-type cases aren't due to amp.
# Just pass through?
return orig_fn(seq, *args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
def promote_match_arg0(mod, fn, handle, verbose=False):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
@functools.wraps(orig_fn)
def wrapper(arg0, *args, **kwargs):
assert compat.is_tensor_like(arg0)
if not _amp_state.handle.is_active():
return orig_fn(arg0, *args, **kwargs)
if utils.type_string(arg0) == 'HalfTensor':
cast_fn = utils.maybe_half
elif utils.type_string(arg0) == 'BFloat16Tensor':
cast_fn = utils.maybe_bfloat16
elif utils.type_string(arg0) == 'FloatTensor':
cast_fn = utils.maybe_float
else:
return orig_fn(arg0, *args, **kwargs)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
return orig_fn(arg0, *new_args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
def err_if_any_half(mod, fn, handle, custom_err_msg=None):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
types = utils.collect_fp_tensor_types(args, kwargs)
if 'HalfTensor' in types or 'BFloat16Tensor' in types:
if custom_err_msg:
raise NotImplementedError(custom_err_msg)
else:
raise NotImplementedError('Cannot call in-place function ' +
'{} with fp16 or bfloat16 args.'.format(fn))
else:
return orig_fn(*args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
def err_if_arg0_half(mod, fn, handle, verbose=False):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
@functools.wraps(orig_fn)
def wrapper(arg0, *args, **kwargs):
assert compat.is_tensor_like(arg0)
if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}:
raise NotImplementedError('Cannot call in-place method ' +
'{} with fp16 or bfloat16 args.'.format(fn))
else:
cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
return orig_fn(arg0, *new_args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
# Current RNN approach:
# - Wrap top-level `RNN` function in thnn backend
# - Will call into either CudnnRNN or AutogradRNN
# - Each of these are factory functions that return a per-iter
# `forward` function
# - We interpose on the factory function to:
# 1) Interpose on the actual forward function and put in casts
# 2) Insert an fp16 `flat_weight` if necessary
def rnn_cast(backend, fn, handle, verbose=False):
orig_rnn = utils.get_func(backend, fn)
@functools.wraps(orig_rnn)
def rnn_wrapper(*args, **kwargs):
flat_weight = kwargs.get('flat_weight')
if flat_weight is not None:
# We replace `flat_weight` with an uninitialized fp16
# Tensor. The "actual" weight tensors (provided in `forward`),
# will then be set up as ptrs into the buffer and have the
# corresponding fp32 values copied in.
# We need to call `copy` on the "actual" weights so that the
# autograd graph correctly backprops from the wgrads computed
# inside cuDNN (on fp16 weights) into the fp32 weights.
assert utils.type_string(flat_weight) == 'FloatTensor'
if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
# Pre-0.4. A little slower, since it zeros out memory.
flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
else:
flat_weight_fp16 = torch.empty_like(flat_weight,
dtype=torch.float16)
kwargs['flat_weight'] = flat_weight_fp16
else:
flat_weight_fp16 = None
forward = orig_rnn(*args, **kwargs)
@functools.wraps(forward)
def fwd_wrapper(*fargs, **fkwargs):
assert len(fargs) == 3 or len(fargs) == 4
inputs, weights, hiddens = fargs[:3]
assert utils.is_fp_tensor(inputs)
assert isinstance(weights, list)
cast_fn = utils.verbosify(utils.maybe_half,
fn,
verbose)
new_args = []
# 0) Inputs
new_args.append(cast_fn(inputs))
# 1) Weights
if flat_weight_fp16 is not None:
fp16_weights = utils.synthesize_flattened_rnn_weights(
weights, flat_weight_fp16, fn, verbose)
else:
fp16_weights = [[cast_fn(w) for w in layer]
for layer in weights]
new_args.append(fp16_weights)
# 2) Hiddens: either a tuple (for LSTM) or single tensor
if isinstance(hiddens, tuple):
new_args.append(tuple(cast_fn(x) for x in hiddens))
elif utils.is_fp_tensor(hiddens):
new_args.append(cast_fn(hiddens))
else:
# Hiddens can, in principle, be `None` -- pass through
new_args.append(hiddens)
# 3) Batch sizes (0.4 or later only)
if len(fargs) == 4:
new_args.append(fargs[3])
return forward(*new_args, **fkwargs)
return fwd_wrapper
utils.set_func_save(handle, backend, fn, rnn_wrapper)
def new_rnn_cast(fn, cast_fn, handle, verbose=False):
# Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
# For rnn backend calls that route through _rnn_impls, we must patch the ref
# that _rnn_impls stashed. For rnn backend calls that directly invoke
# _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,
# which in turn has patched the ref named "_VF" in torch.nn.modules.rnn.
if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):
mod = torch.nn.modules.rnn._rnn_impls
else:
mod = torch.nn.modules.rnn._VF
assert isinstance(mod, rnn_compat.VariableFunctionsShim)
fn = fn.lower()
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
# Exact call signature from modules/rnn.py
assert len(args) == 9
assert len(kwargs) == 0
if not _amp_state.handle.is_active():
return orig_fn(*args, **kwargs)
if isinstance(args[6], bool):
params_idx = 2 # Not PackedSequence case
else:
params_idx = 3 # PackedSequence case
if cast_fn == utils.maybe_half:
dtype = torch.half
elif cast_fn == utils.maybe_bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("Unsupported cast_fn passed. Supports only maybe_half and maybe_bfloat16")
new_args = []
for i, arg in enumerate(args):
if i == params_idx:
num_params = sum([x.numel() for x in arg])
fp16_weight_buf = args[0].new_empty((num_params,),
dtype=dtype)
casted_weights = utils.new_synthesize_flattened_rnn_weights(
arg, fp16_weight_buf, fn, dtype, verbose)
new_args.append(casted_weights)
elif utils.is_fp_tensor(arg):
new_args.append(cast_fn(arg))
else:
new_args.append(arg)
return orig_fn(*new_args)
utils.set_func_save(handle, mod, fn, wrapper)
def disable_casts(mod, fn, handle):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
with handle._disable_casts():
return orig_fn(*args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
from .bottleneck import Bottleneck, SpatialBottleneck
from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
import torch
from apex.contrib.bottleneck import Bottleneck, SpatialBottleneck
from apex.contrib.bottleneck import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
from apex.contrib.peer_memory import PeerMemoryPool
def ground_truth_bottleneck(C, dtype, explicit_nhwc):
bottleneck = Bottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc)
bottleneck.to(dtype=dtype, device='cuda')
for p in bottleneck.parameters():
torch.distributed.broadcast(p, 0)
for b in bottleneck.buffers():
torch.distributed.broadcast(b, 0)
return bottleneck
def print_bottleneck_p_and_b(bottleneck):
with torch.no_grad():
for n,p in bottleneck.named_parameters():
print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32))))
for n,p in bottleneck.named_buffers():
print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32))))
def has_nan(x):
if isinstance(x, list) or isinstance(x, tuple):
for xx in x:
if torch.any(torch.isnan(xx)):
return True
return False
elif isinstance(x, dict):
for k,v in x.items():
if torch.any(torch.isnan(v)):
return True
return False
else:
return torch.any(torch.isnan(x))
def rel_diff_t(xx1, xx2):
return ((xx1 - xx2).norm(p=2,dtype=torch.float32) / (xx1 + xx2).norm(p=2,dtype=torch.float32)).item()
def rel_diff(x1, x2):
if isinstance(x1, list) or isinstance(x1, tuple):
return [rel_diff_t(xx1,xx2) for xx1,xx2 in zip(x1,x2)]
elif isinstance(x1, dict):
return [rel_diff_t(xx1, xx2) for (k1,xx1), (k2,xx2) in zip(x1.items(),x2.items())]
else:
return rel_diff_t(x1,x2)
def graph_it(bottleneck, x):
print("Graphing")
with torch.no_grad():
x = x.clone()
x.grad = None
x.requires_grad = True
return torch.cuda.make_graphed_callables(bottleneck, (x,))
def clone_inputs(bottleneck, x, dy=None):
with torch.no_grad():
x = x.clone()
x.grad = None
x.requires_grad = True
if dy is None:
y = bottleneck(x)
dy = torch.randn_like(y) / 1e2
torch.distributed.broadcast(dy, 0)
return x, dy
def fprop_and_bprop(bottleneck, x, dy):
y = bottleneck(x)
y.backward(dy)
dgrad = x.grad.detach()
wgrad = {}
for n,p in bottleneck.named_parameters():
wgrad[n] = p.grad.detach()
return x, y, dy, dgrad, wgrad
def ground_truth(N, C, H, W, dtype, memory_format, bottleneck):
if memory_format == 1:
# 1 -> explicit nhwc
explicit_nhwc = True
with torch.no_grad():
x = torch.randn([N,H,W,C], dtype=dtype, device='cuda')
torch.distributed.broadcast(x, 0)
x, dy = clone_inputs(bottleneck, x)
return fprop_and_bprop(bottleneck, x, dy)
else:
# 2 -> native nhwc
# 3 -> nchw
explicit_nhwc = False
assert(False), "Not implemented yet"
def print_ground_truth(gt):
x, y, dy, dgrad, wgrad = gt
if has_nan(y) or has_nan(dgrad) or has_nan(wgrad):
print("Error! Ground truth has NAN")
else:
print("Ok! No NAN found in ground truth")
def apply_to_different_bottleneck(gt, bottleneck):
with torch.no_grad():
x, _, dy, _, _ = gt
x, dy = clone_inputs(bottleneck, x, dy)
return fprop_and_bprop(bottleneck, x, dy)
def compare_single_field(results, f1, f2, l0, l1, l2):
if has_nan(f1) and has_nan(f2):
results[l0] = "both NAN"
elif has_nan(f1):
results[l0] = "%s.%s NAN" % (l1, l0)
elif has_nan(f2):
results[l0] = "%s.%s NAN" % (l2, l0)
else:
results[l0] = "%s" % (str(rel_diff(f1,f2)))
def compare(gt, bt):
x1, y1, dy1, dgrad1, wgrad1 = gt
x2, y2, dy2, dgrad2, wgrad2 = bt
results = {}
compare_single_field(results, y1, y2, "y", "gt", "bt")
compare_single_field(results, dy1, dy2, "dy", "gt", "bt")
compare_single_field(results, dgrad1, dgrad2, "dgrad", "gt", "bt")
compare_single_field(results, wgrad1, wgrad2, "wgrad", "gt", "bt")
for i in range(torch.distributed.get_world_size()):
if i == torch.distributed.get_rank():
print(i,results)
torch.distributed.barrier()
def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args):
spatial_bottleneck = SpatialBottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc,spatial_parallel_args=spatial_parallel_args)
spatial_bottleneck.to(dtype=dtype, device='cuda')
with torch.no_grad():
sp = {}
for n,p in spatial_bottleneck.named_parameters():
sp[n] = p
for n,p in gt_bottleneck.named_parameters():
sp[n].copy_(p)
sb = {}
for n,b in spatial_bottleneck.named_buffers():
sb[n] = b
for n,b in gt_bottleneck.named_buffers():
sb[n].copy_(b)
return spatial_bottleneck
def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=False):
assert(explicit_nhwc), "Only tested for explicit nhwc"
x, _, dy, _, _ = gt
N, H, W, C = list(x.shape) # Tensor is already shaped properly for n-way parallel
dtype = x.dtype
spatial_group_size = world_size
spatial_group_rank = rank
spatial_communicator = None
spatial_halo_exchanger = halex
spatial_method = 1 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x
use_delay_kernel = False
spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel)
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args)
with torch.no_grad():
Hs = H // spatial_group_size
xs = x[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
dys = dy[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone()
xs.requires_grad = True
spatial_bottleneck = graph_it(spatial_bottleneck, xs)
_, y, _, dgrad, wgrad = fprop_and_bprop(spatial_bottleneck, xs, dys)
# gather output pieces
for n,p in wgrad.items():
if fp32_reduce:
p32 = p.float()
torch.distributed.all_reduce(p32)
p.copy_(p32.half())
else:
torch.distributed.all_reduce(p)
ys = [torch.empty_like(y) for _ in range(spatial_group_size)]
torch.distributed.all_gather(ys,y)
y = torch.cat(ys,dim=1)
dgrads = [torch.empty_like(dgrad) for _ in range(spatial_group_size)]
torch.distributed.all_gather(dgrads,dgrad)
dgrad = torch.cat(dgrads,dim=1)
return x, y, dy, dgrad, wgrad
def main():
torch.use_deterministic_algorithms(True)
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
explicit_nhwc = True
dtype = torch.float16
N, C, H, W = 1, 64, 200, 336
Hs = ((H+8*world_size-1) // (8*world_size)) * 8
H = Hs*world_size
gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc)
gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck)
# verify that spatial bottleneck with group_size 1 produces same results as ground truth bottleneck
spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, None)
bt = apply_to_different_bottleneck(gt, spatial_bottleneck)
compare(gt, bt)
#print_bottleneck_p_and_b(gt_bottleneck)
#print_bottleneck_p_and_b(spatial_bottleneck)
group_size = world_size
group = rank // group_size
ranks = [group*group_size+i for i in range(group_size)]
rank_in_group = rank % group_size
spatial_group_size = world_size
spatial_communicator = None
peer_pool = PeerMemoryPool(64*1024*1024, 2*1024*1024, ranks)
#class HaloExchangerNoComm(HaloExchanger):
# def __init__(self, ranks, rank_in_group):
#class HaloExchangerAllGather(HaloExchanger):
# def __init__(self, ranks, rank_in_group, comm):
#class HaloExchangerSendRecv(HaloExchanger):
# def __init__(self, ranks, rank_in_group):
#class HaloExchangerPeer(HaloExchanger):
# def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1):
#halex = HaloExchangerAllGather(ranks, rank_in_group)
#halex = HaloExchangerSendRecv(ranks, rank_in_group)
halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1)
#print("halex.signals = %s" % (str(halex.signals)))
# Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding
#torch.cuda.synchronize()
#torch.distributed.barrier()
bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True)
compare(gt, bt2)
if __name__ == "__main__":
main()
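
# Editorial usage note (not part of the original file): this test initializes
# torch.distributed with the NCCL backend, so it expects the standard launcher
# environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT). A typical
# invocation, with the file name as a placeholder, would be:
#   torchrun --nproc_per_node=2 <this_test_file>.py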
from .clip_grad import clip_grad_norm_
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU