Commit 1811808c authored by JR_ZZU

add new files

parent c2b62b7f
# TODO: think about the following two. They do weird things.
# - torch.nn.utils.clip_grad (but it should always be fp32 anyway)
# - torch.nn.utils.weight_norm
# Notes:
# F.instance_norm uses batch_norm internally, which correctly handles
# fp16 in/out with fp32 weights, so we shouldn't do anything for
# either of these.
# F.normalize calls `input.norm()` internally, so it's redundant, but
# kept here in case the implementation changes.
# F.cosine_similarity is the same: it calls `x.norm()` internally.
import torch.nn.functional
MODULE = torch.nn.functional
FP16_FUNCS = [
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc', # Undocumented / maybe new?
'linear',
]
FP32_FUNCS = [
# Interpolation/Upsampling TODO: Remove for 1.2
'interpolate',
'grid_sample',
# Pointwise
'softplus',
'softmin',
'log_softmax',
'softmax',
'gelu',
# Normalization
'layer_norm',
'group_norm',
'local_response_norm',
'normalize',
'cosine_similarity',
# Loss functions
# TODO: which of these can be fp16?
'poisson_nll_loss',
'cosine_embedding_loss',
'cross_entropy',
'hinge_embedding_loss',
'kl_div',
'l1_loss',
'mse_loss',
'margin_ranking_loss',
'multilabel_margin_loss',
'multilabel_soft_margin_loss',
'multi_margin_loss',
'nll_loss',
'binary_cross_entropy_with_logits',
'smooth_l1_loss',
'soft_margin_loss',
'triplet_margin_loss',
'ctc_loss'
]
BANNED_FUNCS = [
('binary_cross_entropy',
("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
"It requires that the output of the previous function be already a FloatTensor. \n\n"
"Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
" torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
"that is compatible with amp.\nAnother option is to add\n"
" amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
"If you _really_ know what you are doing, you can disable this warning by passing "
"allow_banned=True to `amp.init()`."))
]
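# Hedged, illustrative sketch of the workaround suggested in the message above;
# it assumes the older `amp.init()` handle API and is not called from these lists.
def _bce_workaround_example():
    from apex import amp
    # Either keep Sigmoid + BCELoss and force sigmoid to run in fp32 under amp...
    amp.register_float_function(torch, 'sigmoid')
    amp_handle = amp.init()
    # ...or, preferably, replace Sigmoid + BCELoss with torch.nn.BCEWithLogitsLoss.
    return amp_handle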
from .. import compat
from . import torch_overrides
import importlib
import torch
# if compat.variable_is_tensor() and not compat.tensor_is_variable():
MODULE = torch.Tensor
# else:
# MODULE = torch.autograd.Variable
FP16_FUNCS = compat.filter_attrs(MODULE, [
'__matmul__',
])
FP32_FUNCS = compat.filter_attrs(MODULE, [
'__ipow__',
'__pow__',
'__rpow__',
# Cast to fp32 before transfer to CPU
'cpu',
])
CASTS = compat.filter_attrs(MODULE, [
'__add__',
'__div__',
'__eq__',
'__ge__',
'__gt__',
'__iadd__',
'__idiv__',
'__imul__',
'__isub__',
'__itruediv__',
'__le__',
'__lt__',
'__mul__',
'__ne__',
'__radd__',
'__rdiv__',
'__rmul__',
'__rsub__',
'__rtruediv__',
'__sub__',
'__truediv__',
])
# None of these, but here to make code cleaner.
SEQUENCE_CASTS = []
# We need to grab all the methods from torch_overrides and add them to
# the Tensor lists as well, as almost all methods are duplicated
# between `torch` and `torch.Tensor` (and check with `hasattr`,
# because a few random ones aren't defined on Tensor)
_self_mod = importlib.import_module(__name__)
for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
lst = getattr(_self_mod, attrname)
for fn in getattr(torch_overrides, attrname):
if hasattr(MODULE, fn):
lst.append(fn)
import torch
from .. import utils
MODULE = torch
FP16_FUNCS = [
# Low level functions wrapped by torch.nn layers.
# The wrapper layers contain the weights which are then passed in as a parameter
# to these functions.
'conv1d',
'conv2d',
'conv3d',
'conv_transpose1d',
'conv_transpose2d',
'conv_transpose3d',
'conv_tbc',
'prelu',
# BLAS
'addmm',
'addmv',
'addr',
'matmul',
'mm',
'mv',
]
FP32_FUNCS = [
# Pointwise
'acos',
'asin',
'cosh',
'erfinv',
'exp',
'expm1',
'log',
'log10',
'log2',
'reciprocal',
'rsqrt',
'sinh',
'tan',
# Other math
'pow',
# Reduction
'cumprod',
'cumsum',
'dist',
# 'mean',
'norm',
'prod',
'std',
'sum',
'var',
# Misc
'renorm'
]
version_strings = torch.__version__.split('.')
version_major = version_strings[0]
version_minor = version_strings[1]
version_num = float(version_major + "." + version_minor)
# Before torch 1.1, mean must be blacklisted.
if version_num < 1.1:
FP32_FUNCS.append('mean')
# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
# check the CUDA version -- if at least 9.1, then put the bmm
# functions on the fp16 list. Otherwise, put them on the fp32 list.
_bmms = ['addbmm',
'baddbmm',
'bmm']
if utils.is_cuda_enabled():
# workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802
if utils.get_cuda_version() >= (9, 1, 0):
FP16_FUNCS.extend(_bmms)
else:
FP32_FUNCS.extend(_bmms)
# Multi-tensor fns that may need type promotion
CASTS = [
# Multi-tensor math
'addcdiv',
'addcmul',
'atan2',
'cross',
'bilinear',
'dot',
# Element-wise _or_ tensor-wise math
'add',
'div',
'mul',
# Comparison
'eq',
'equal',
'ge',
'gt',
'le',
'lt',
'ne'
]
# Functions that take sequence arguments. We need to inspect the whole
# sequence and cast to the widest type.
SEQUENCE_CASTS = [
'cat',
'stack'
]
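# Hedged sketch of what "widest type" promotion means for the SEQUENCE_CASTS
# above: with an active amp handle, a mixed half/float sequence passed to
# `torch.cat` is cast to fp32 before concatenation. Shapes here are illustrative.
def _sequence_cast_example():
    a = torch.randn(2, 3).cuda().half()
    b = torch.randn(2, 3).cuda().float()
    out = torch.cat([a, b])  # amp casts `a` up to fp32
    return out.dtype         # torch.float32 when amp is active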
import contextlib
import warnings
from .scaler import LossScaler, master_params
from ._amp_state import maybe_print
import numpy as np
class OptimWrapper(object):
def __init__(self, optimizer, amp_handle, num_loss):
self._optimizer = optimizer
self._amp_handle = amp_handle
self._num_loss = num_loss
self._loss_idx = 0
self._skip_next = [False] * num_loss
self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)]
@contextlib.contextmanager
def scale_loss(self, loss):
if not self._amp_handle.is_active():
yield loss
return
# When there are multiple losses per optimizer, we need
# to save the current grad accumulation, since we won't be
# able to unscale this particular loss once the grads are
# all mixed together.
cached_grads = []
if self._loss_idx > 0:
for p in master_params(self._optimizer):
if p.grad is not None:
cached_grads.append(p.grad.data.detach().clone())
else:
cached_grads.append(None)
self._optimizer.zero_grad()
loss_scale = self._cur_loss_scaler().loss_scale()
yield loss * loss_scale
self._cur_loss_scaler().clear_overflow_state()
self._cur_loss_scaler().unscale(
master_params(self._optimizer),
master_params(self._optimizer),
loss_scale)
self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()
self._loss_idx += 1
if len(cached_grads) > 0:
for p, cached_grad in zip(master_params(self._optimizer),
cached_grads):
if cached_grad is not None:
p.grad.data.add_(cached_grad)
cached_grads = []
def _cur_loss_scaler(self):
assert 0 <= self._loss_idx < self._num_loss
return self._loss_scaler[self._loss_idx]
def step(self, closure=None):
if not self._amp_handle.is_active():
return self._optimizer.step(closure=closure)
self._loss_idx = 0
for group in self._optimizer.param_groups:
for p in group['params']:
self._amp_handle.remove_cache(p)
if closure is not None:
raise NotImplementedError(
'The `closure` argument is unsupported by the amp ' +
'optimizer wrapper.')
if any(self._skip_next):
maybe_print('Gradient overflow, skipping update')
self._skip_next = [False] * self._num_loss
else:
return self._optimizer.step(closure=closure)
# Forward any attribute lookups
def __getattr__(self, attr):
return getattr(self._optimizer, attr)
# Forward all torch.optim.Optimizer methods
def __getstate__(self):
return self._optimizer.__getstate__()
def __setstate__(self, state):
return self._optimizer.__setstate__(state)
def __repr__(self):
return self._optimizer.__repr__()
def state_dict(self):
return self._optimizer.state_dict()
def load_state_dict(self, state_dict):
return self._optimizer.load_state_dict(state_dict)
def zero_grad(self):
return self._optimizer.zero_grad()
def add_param_group(self, param_group):
return self._optimizer.add_param_group(param_group)
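# Hedged usage sketch of OptimWrapper with two losses sharing one optimizer,
# assuming the older handle API (`amp_handle.wrap_optimizer`); names are
# illustrative and backward-graph details are elided.
def _optim_wrapper_example(amp_handle, optimizer, loss0, loss1):
    optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)
    with optimizer.scale_loss(loss0) as scaled_loss:
        scaled_loss.backward()
    with optimizer.scale_loss(loss1) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()  # skipped internally if either loss scaler saw an overflow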
from . import utils, wrap
import torch
_VF = torch._C._VariableFunctions
RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']
def _gen_VF_wrapper(name):
def wrapper(*args, **kwargs):
return getattr(_VF, name)(*args, **kwargs)
return wrapper
# Some python magic to generate an object that has the rnn cell functions
# defined on it, all of which call into corresponding _VF version.
# Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF"
# imported at module scope within torch.nn.modules.rnn). This should
# not affect third-party importers of _VF.py.
class VariableFunctionsShim(object):
def __init__(self):
for name in RNN_NAMES:
for suffix in ['', '_cell']:
fn_name = name + suffix
setattr(self, fn_name, _gen_VF_wrapper(fn_name))
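# Minimal sketch of how this shim is intended to be installed (the actual
# installation happens in amp's initialization code, not in this file).
def _install_vf_shim_example():
    # Replace the module-level "_VF" reference used by torch.nn.modules.rnn so
    # RNN cell calls route through the wrappers generated above.
    torch.nn.modules.rnn._VF = VariableFunctionsShim()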
def has_old_rnns():
try:
torch.nn.backends.thnn.backend.LSTMCell
return True
except:
return False
def whitelist_rnn_cells(handle, verbose):
# Different module + function names in old/new RNN cases
if has_old_rnns():
fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
mod = torch.nn.backends.thnn.backend
else:
fn_names = [x + '_cell' for x in RNN_NAMES]
mod = torch.nn.modules.rnn._VF
assert isinstance(mod, VariableFunctionsShim)
# Insert casts on cell functions
for fn in fn_names:
wrap.cached_cast(mod, fn, utils.maybe_half, handle,
try_caching=True, verbose=verbose)
if has_old_rnns():
# Special handling of `backward` for fused gru / lstm:
# The `backward` method calls Tensor.sum() (blacklist) internally,
# and then the resulting grad_input has the wrong type.
# TODO: where else is this a problem?
for rnn_type in ['GRUFused', 'LSTMFused']:
mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
wrap.disable_casts(mod, 'backward', handle)
import torch
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import _amp_state, master_params, maybe_print
from itertools import product
def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
# Exception handling for 18.04 compatibility
if check_overflow:
cpu_sum = float(model_grad.float().sum())
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
if master_grad is not model_grad: # copy_ probably internally short-circuits this
master_grad.copy_(model_grad)
if scale != 1.0:
master_grad.mul_(scale)
return False
def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
# Exception handling for 18.04 compatibility
if check_overflow:
cpu_sum = float(model_grad.float().sum())
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
# if master_grad is not model_grad: # copy_ probably internally short-circuits this
# master_grad.copy_(model_grad)
assert stashed_grad.dtype == master_grad.dtype
converted_model_grad = model_grad.data.to(master_grad.dtype)
master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
return False
class LossScaler(object):
warned_no_fused_kernel = False
warned_unscaling_non_fp32_grad = False
has_fused_kernel = False
def __init__(self,
loss_scale,
init_scale=2.**16,
scale_factor=2.,
scale_window=2000,
min_loss_scale=None,
max_loss_scale=2.**24):
if loss_scale == "dynamic":
self.dynamic = True
self._loss_scale = min(max_loss_scale, init_scale)
else:
self.dynamic = False
self._loss_scale = loss_scale
self._max_loss_scale = max_loss_scale
self._min_loss_scale = min_loss_scale
self._scale_seq_len = scale_window
self._unskipped = 0
self._has_overflow = False
self._overflow_buf = torch.cuda.IntTensor([0])
if multi_tensor_applier.available:
import amp_C
LossScaler.has_fused_kernel = multi_tensor_applier.available
LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
else:
if not LossScaler.warned_no_fused_kernel:
maybe_print(
"Warning: multi_tensor_applier fused unscale kernel is unavailable, "
"possibly because apex was installed without --cuda_ext --cpp_ext. "
"Using Python fallback. Original ImportError was: " +
repr(multi_tensor_applier.import_err),
True)
LossScaler.has_fused_kernel = False
LossScaler.warned_no_fused_kernel = True
def loss_scale(self):
return self._loss_scale
def unscale_python(self, model_grads, master_grads, scale):
for model, master in zip(model_grads, master_grads):
if model is not None:
if not LossScaler.warned_unscaling_non_fp32_grad:
if master.dtype != torch.float32:
maybe_print(
"Attempting to unscale a grad with type {} ".format(master.type()) +
"Unscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
LossScaler.warned_unscaling_non_fp32_grad = True
self._has_overflow = scale_check_overflow_python(model,
master,
1./scale,
self.dynamic)
if self._has_overflow and self.dynamic:
break
# unused_scale keeps some of the old API alive for hopefully a short time.
def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
if self._has_overflow:
return
scale = self._loss_scale
if scale_override is not None:
scale = scale_override
if scale == 1.0 and models_are_masters and not self.dynamic:
return
if LossScaler.has_fused_kernel:
# if (not LossScaler.warned_unscaling_non_fp32_grad
# and master_grads[0].dtype == torch.float16):
# print("Warning: unscaling grads that are not FP32. "
# "Unscaling non-fp32 grads may indicate an error. "
# "When using Amp, you don't need to call .half() on your model.")
# # Setting this to True unconditionally allows the possibility of an escape
# # if never-before-seen non-fp32 grads are created in some later iteration.
# LossScaler.warned_unscaling_non_fp32_grad = True
multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
self._overflow_buf,
[model_grads, master_grads],
1./scale)
else:
self.unscale_python(model_grads, master_grads, scale)
# Defer to update_scale
# If the fused kernel is available, we only need one D2H memcopy and sync.
# if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
# self._has_overflow = self._overflow_buf.item()
def unscale_with_stashed_python(self,
model_grads,
stashed_master_grads,
master_grads,
a,
b):
for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
if model is None and stashed is None:
continue
else:
if not LossScaler.warned_unscaling_non_fp32_grad:
if master.dtype != torch.float32:
maybe_print(
"Attempting to unscale a grad with type {} ".format(master.type()) +
"Unscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
LossScaler.warned_unscaling_non_fp32_grad = True
self._has_overflow = axpby_check_overflow_python(model,
stashed,
master,
a,
b,
self.dynamic)
if self._has_overflow and self.dynamic:
break
def unscale_with_stashed(self,
model_grads,
stashed_master_grads,
master_grads,
scale_override=None):
if self._has_overflow:
return
grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
if scale_override is not None:
grads_have_scale, stashed_have_scale, out_scale = scale_override
if LossScaler.has_fused_kernel:
if (not LossScaler.warned_unscaling_non_fp32_grad
and master_grads[0].dtype == torch.float16):
print("Warning: unscaling grads that are not FP32. "
"Unscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
# Setting this to True unconditionally allows the possibility of an escape
# if never-before-seen non-fp32 grads are created in some later iteration.
LossScaler.warned_unscaling_non_fp32_grad = True
multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
self._overflow_buf,
[model_grads, stashed_master_grads, master_grads],
out_scale/grads_have_scale, # 1./scale,
out_scale/stashed_have_scale, # 1.0,
0) # check only arg 0, aka the incoming model grads, for infs
else:
self.unscale_with_stashed_python(model_grads,
stashed_master_grads,
master_grads,
out_scale/grads_have_scale,
out_scale/stashed_have_scale)
# Defer to update_scale
# If the fused kernel is available, we only need one D2H memcopy and sync.
# if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
# self._has_overflow = self._overflow_buf.item()
def clear_overflow_state(self):
self._has_overflow = False
if self.has_fused_kernel:
self._overflow_buf.zero_()
# Separate so unscale() can be called more than once before updating.
def update_scale(self):
# If the fused kernel is available, we only need one D2H memcopy and sync.
if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
self._has_overflow = self._overflow_buf.item()
if self._has_overflow and self.dynamic:
should_skip = True
if self._min_loss_scale:
self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
else:
self._loss_scale = self._loss_scale/2.
self._unskipped = 0
else:
should_skip = False
self._unskipped += 1
if self._unskipped == self._scale_seq_len and self.dynamic:
self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
self._unskipped = 0
return should_skip
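# Hedged sketch of one loss-scaling iteration against the LossScaler API above;
# `model_grads`/`master_grads` stand in for the gradient lists an optimizer
# wrapper would collect, and optimizer stepping is elided.
def _loss_scaler_example(model_grads, master_grads):
    scaler = LossScaler("dynamic")
    scale = scaler.loss_scale()
    # ... backward() is run elsewhere on `loss * scale` ...
    scaler.clear_overflow_state()
    scaler.unscale(model_grads, master_grads, scale)
    should_skip = scaler.update_scale()  # True -> skip this optimizer step
    return should_skip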
from . import compat
import functools
import itertools
import torch
def is_cuda_enabled():
return torch.version.cuda is not None
def get_cuda_version():
return tuple(int(x) for x in torch.version.cuda.split('.'))
def is_fp_tensor(x):
if is_nested(x):
# Fast-fail version of all(is_fp_tensor)
for y in x:
if not is_fp_tensor(y):
return False
return True
return compat.is_tensor_like(x) and compat.is_floating_point(x)
def is_nested(x):
return isinstance(x, tuple) or isinstance(x, list)
def should_cache(x):
if is_nested(x):
# Fast-fail version of all(should_cache)
for y in x:
if not should_cache(y):
return False
return True
return isinstance(x, torch.nn.parameter.Parameter) and \
type_string(x) == 'FloatTensor'
def collect_fp_tensor_types(args, kwargs):
def collect_types(x, types):
if is_nested(x):
for y in x:
collect_types(y, types)
else:
types.add(type_string(x))
all_args = itertools.chain(args, kwargs.values())
types = set()
for x in all_args:
if is_fp_tensor(x):
collect_types(x, types)
return types
def type_string(x):
return x.type().split('.')[-1]
def maybe_half(x, name='', verbose=False):
if is_nested(x):
return type(x)([maybe_half(y) for y in x])
if not x.is_cuda or type_string(x) == 'HalfTensor':
return x
else:
if verbose:
print('Float->Half ({})'.format(name))
return x.half()
def maybe_float(x, name='', verbose=False):
if is_nested(x):
return type(x)([maybe_float(y) for y in x])
if not x.is_cuda or type_string(x) == 'FloatTensor':
return x
else:
if verbose:
print('Half->Float ({})'.format(name))
return x.float()
# NB: returns a casted copy of `args`; mutates `kwargs` in-place
def casted_args(cast_fn, args, kwargs):
new_args = []
for x in args:
if is_fp_tensor(x):
new_args.append(cast_fn(x))
else:
new_args.append(x)
for k in kwargs:
val = kwargs[k]
if is_fp_tensor(val):
kwargs[k] = cast_fn(val)
return new_args
def cached_cast(cast_fn, x, cache):
if is_nested(x):
return type(x)([cached_cast(cast_fn, y, cache) for y in x])
if x in cache:
cached_x = cache[x]
if x.requires_grad and cached_x.requires_grad:
# Make sure x is actually cached_x's autograd parent.
if len(cached_x.grad_fn.next_functions) > 1 and cached_x.grad_fn.next_functions[1][0].variable is not x:
raise RuntimeError("x and cache[x] both require grad, but x is not "
"cache[x]'s parent. This is likely an error.")
# During eval, it's possible to end up caching casted weights with
# requires_grad=False. On the next training iter, if cached_x is found
# and reused from the cache, it will not actually have x as its parent.
# Therefore, we choose to invalidate the cache (and force refreshing the cast)
# if x.requires_grad and cached_x.requires_grad do not match.
#
# During eval (i.e. running under with torch.no_grad()) the invalidation
# check would cause the cached value to be dropped every time, because
# cached_x would always be created with requires_grad=False, while x would
# still have requires_grad=True. This would render the cache effectively
# useless during eval. Therefore, if we are running under the no_grad()
# context manager (torch.is_grad_enabled=False) we elide the invalidation
# check, and use the cached value even though its requires_grad flag doesn't
# match. During eval, we don't care that there's no autograd-graph
# connection between x and cached_x.
if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
del cache[x]
else:
return cached_x
casted_x = cast_fn(x)
cache[x] = casted_x
return casted_x
def verbosify(cast_fn, fn_name, verbose):
if verbose:
return functools.partial(cast_fn, name=fn_name, verbose=verbose)
else:
return cast_fn
def as_inplace(fns):
for x in fns:
yield x + '_'
def has_func(mod, fn):
if isinstance(mod, dict):
return fn in mod
else:
return hasattr(mod, fn)
def get_func(mod, fn):
if isinstance(mod, dict):
return mod[fn]
else:
return getattr(mod, fn)
def set_func(mod, fn, new_fn):
if isinstance(mod, dict):
mod[fn] = new_fn
else:
setattr(mod, fn, new_fn)
def set_func_save(handle, mod, fn, new_fn):
cur_fn = get_func(mod, fn)
handle._save_func(mod, fn, cur_fn)
set_func(mod, fn, new_fn)
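# Hedged sketch showing that has_func/get_func/set_func treat dict-like
# backends (old THNN function tables) and ordinary modules uniformly; the
# dict contents here are illustrative.
def _func_accessor_example():
    backend_table = {'linear': lambda x: x}
    assert has_func(backend_table, 'linear') and has_func(torch, 'cat')
    orig = get_func(backend_table, 'linear')
    set_func(backend_table, 'linear', orig)      # no-op round trip on a dict
    return get_func(torch, 'cat') is torch.cat   # module attribute path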
# A couple problems get solved here:
# - The flat_weight buffer is disconnected from autograd graph,
# so the fp16 weights need to be derived from the input weights
# to this forward call, not the flat buffer.
# - The ordering of weights in the flat buffer is...idiosyncratic.
# The first problem is solved with a combination of set_ (to set up the
# correct storage) and copy_ (so the fp16 weight derives from the
# fp32 one in autograd).
# The second is solved by doing ptr arithmetic on the fp32 weights
# to derive the correct offset.
#
# TODO: maybe this should actually use
# `torch._cudnn_rnn_flatten_weight`? But then I need to call
# on first iter and cache the right offsets. Ugh.
def synthesize_flattened_rnn_weights(fp32_weights,
fp16_flat_tensor,
rnn_fn='',
verbose=False):
fp16_weights = []
fp32_base_ptr = fp32_weights[0][0].data_ptr()
for layer_weights in fp32_weights:
fp16_layer_weights = []
for w_fp32 in layer_weights:
w_fp16 = w_fp32.new().half()
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
w_fp16.set_(fp16_flat_tensor.storage(),
offset,
w_fp32.shape)
w_fp16.copy_(w_fp32)
if verbose:
print('Float->Half ({})'.format(rnn_fn))
fp16_layer_weights.append(w_fp16)
fp16_weights.append(fp16_layer_weights)
return fp16_weights
# Roughly same as above, just the `fp32_weights` aren't nested.
# Code kept separate for readability.
def new_synthesize_flattened_rnn_weights(fp32_weights,
fp16_flat_tensor,
rnn_fn='',
verbose=False):
fp16_weights = []
fp32_base_ptr = fp32_weights[0].data_ptr()
for w_fp32 in fp32_weights:
w_fp16 = w_fp32.new().half()
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
w_fp16.set_(fp16_flat_tensor.storage(),
offset,
w_fp32.shape)
w_fp16.copy_(w_fp32)
if verbose:
print('Float->Half ({})'.format(rnn_fn))
fp16_weights.append(w_fp16)
return fp16_weights
from . import compat
from . import utils
from ._amp_state import _amp_state
from . import rnn_compat
import functools
import torch
def make_cast_wrapper(orig_fn, cast_fn, handle,
try_caching=False):
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
if not handle.is_active():
return orig_fn(*args, **kwargs)
if try_caching and handle.has_cache:
args = list(args)
for i in range(len(args)):
if utils.should_cache(args[i]):
args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)
for k in kwargs:
if utils.should_cache(kwargs[k]):
kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)
new_args = utils.casted_args(cast_fn,
args,
kwargs)
return orig_fn(*new_args, **kwargs)
return wrapper
def cached_cast(mod, fn, cast_fn, handle,
try_caching=False, verbose=False):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
utils.set_func_save(handle, mod, fn, wrapper)
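# Hedged sketch of how `cached_cast` is applied (mirroring the way amp
# whitelists fp16-safe functions, e.g. the RNN cells in rnn_compat): patch a
# function so its floating-point args are cast to half, with weight caching on.
def _cached_cast_example(handle, verbose=False):
    import torch.nn.functional
    cached_cast(torch.nn.functional, 'linear', utils.maybe_half, handle,
                try_caching=True, verbose=verbose)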
# The `handle` arg is unused here, but keeping it makes the signature consistent with `make_cast_wrapper`.
# Annoyingly, make_promote_wrapper still uses the global handle. Once everyone
# is on the new API and I am free to get rid of handle, I can clean this up.
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
if not _amp_state.handle.is_active():
return orig_fn(*args, **kwargs)
types = utils.collect_fp_tensor_types(args, kwargs)
if len(types) <= 1:
return orig_fn(*args, **kwargs)
elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
new_args = utils.casted_args(cast_fn,
args,
kwargs)
return orig_fn(*new_args, **kwargs)
else:
raise NotImplementedError('Do not know how to handle ' +
'these types to promote: {}'
.format(types))
return wrapper
def promote(mod, fn, handle, verbose=False):
orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
wrapper = make_promote_wrapper(orig_fn, maybe_float)
utils.set_func_save(handle, mod, fn, wrapper)
def sequence_promote(mod, fn, handle, verbose=False):
orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
@functools.wraps(orig_fn)
def wrapper(seq, *args, **kwargs):
if not _amp_state.handle.is_active():
return orig_fn(seq, *args, **kwargs)
types = set([utils.type_string(x) for x in seq])
if len(types) <= 1:
return orig_fn(seq, *args, **kwargs)
elif types == set(['HalfTensor', 'FloatTensor']):
cast_seq = utils.casted_args(maybe_float,
seq, {})
return orig_fn(cast_seq, *args, **kwargs)
else:
# TODO: other mixed-type cases aren't due to amp.
# Just pass through?
return orig_fn(seq, *args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
def promote_match_arg0(mod, fn, handle, verbose=False):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
@functools.wraps(orig_fn)
def wrapper(arg0, *args, **kwargs):
assert compat.is_tensor_like(arg0)
if not _amp_state.handle.is_active():
return orig_fn(arg0, *args, **kwargs)
if utils.type_string(arg0) == 'HalfTensor':
cast_fn = utils.maybe_half
elif utils.type_string(arg0) == 'FloatTensor':
cast_fn = utils.maybe_float
else:
return orig_fn(arg0, *args, **kwargs)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
return orig_fn(arg0, *new_args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
def err_if_any_half(mod, fn, handle, custom_err_msg=None):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
types = utils.collect_fp_tensor_types(args, kwargs)
if 'HalfTensor' in types:
if custom_err_msg:
raise NotImplementedError(custom_err_msg)
else:
raise NotImplementedError('Cannot call in-place function ' +
'{} with fp16 arguments.'.format(fn))
else:
return orig_fn(*args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
def err_if_arg0_half(mod, fn, handle, verbose=False):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
@functools.wraps(orig_fn)
def wrapper(arg0, *args, **kwargs):
assert compat.is_tensor_like(arg0)
if utils.type_string(arg0) == 'HalfTensor':
raise NotImplementedError('Cannot call in-place method ' +
'{} on fp16 Tensors.'.format(fn))
else:
cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
return orig_fn(arg0, *new_args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
# Current RNN approach:
# - Wrap top-level `RNN` function in thnn backend
# - Will call into either CudnnRNN or AutogradRNN
# - Each of these are factory functions that return a per-iter
# `forward` function
# - We interpose on the factory function to:
# 1) Interpose on the actual forward function and put in casts
# 2) Insert an fp16 `flat_weight` if necessary
def rnn_cast(backend, fn, handle, verbose=False):
orig_rnn = utils.get_func(backend, fn)
@functools.wraps(orig_rnn)
def rnn_wrapper(*args, **kwargs):
flat_weight = kwargs.get('flat_weight')
if flat_weight is not None:
# We replace `flat_weight` with an uninitialized fp16
# Tensor. The "actual" weight tensors (provided in `forward`),
# will then be set up as ptrs into the buffer and have the
# corresponding fp32 values copied in.
# We need to call `copy` on the "actual" weights so that the
# autograd graph correctly backprops from the wgrads computed
# inside cuDNN (on fp16 weights) into the fp32 weights.
assert utils.type_string(flat_weight) == 'FloatTensor'
if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
# Pre-0.4. A little slower, since it zeros out memory.
flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
else:
flat_weight_fp16 = torch.empty_like(flat_weight,
dtype=torch.float16)
kwargs['flat_weight'] = flat_weight_fp16
else:
flat_weight_fp16 = None
forward = orig_rnn(*args, **kwargs)
@functools.wraps(forward)
def fwd_wrapper(*fargs, **fkwargs):
assert len(fargs) == 3 or len(fargs) == 4
inputs, weights, hiddens = fargs[:3]
assert utils.is_fp_tensor(inputs)
assert isinstance(weights, list)
cast_fn = utils.verbosify(utils.maybe_half,
fn,
verbose)
new_args = []
# 0) Inputs
new_args.append(cast_fn(inputs))
# 1) Weights
if flat_weight_fp16 is not None:
fp16_weights = utils.synthesize_flattened_rnn_weights(
weights, flat_weight_fp16, fn, verbose)
else:
fp16_weights = [[cast_fn(w) for w in layer]
for layer in weights]
new_args.append(fp16_weights)
# 2) Hiddens: either a tuple (for LSTM) or a single tensor
if isinstance(hiddens, tuple):
new_args.append(tuple(cast_fn(x) for x in hiddens))
elif utils.is_fp_tensor(hiddens):
new_args.append(cast_fn(hiddens))
else:
# Hiddens can, in principle, be `None` -- pass through
new_args.append(hiddens)
# 3) Batch sizes (0.4 or later only)
if len(fargs) == 4:
new_args.append(fargs[3])
return forward(*new_args, **fkwargs)
return fwd_wrapper
utils.set_func_save(handle, backend, fn, rnn_wrapper)
def new_rnn_cast(fn, handle, verbose=False):
# Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
# For rnn backend calls that route through _rnn_impls, we must patch the ref
# that _rnn_impls stashed. For rnn backend calls that directly invoke
# _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,
# which in turn has patched the ref named "_VF" in torch.nn.modules.rnn.
if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):
mod = torch.nn.modules.rnn._rnn_impls
else:
mod = torch.nn.modules.rnn._VF
assert isinstance(mod, rnn_compat.VariableFunctionsShim)
fn = fn.lower()
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
# Exact call signature from modules/rnn.py
assert len(args) == 9
assert len(kwargs) == 0
if not _amp_state.handle.is_active():
return orig_fn(*args, **kwargs)
if isinstance(args[6], bool):
params_idx = 2 # Not PackedSequence case
else:
params_idx = 3 # PackedSequence case
new_args = []
for i, arg in enumerate(args):
if i == params_idx:
num_params = sum([x.numel() for x in arg])
fp16_weight_buf = args[0].new_empty((num_params,),
dtype=torch.half)
casted_weights = utils.new_synthesize_flattened_rnn_weights(
arg, fp16_weight_buf, fn, verbose)
new_args.append(casted_weights)
elif utils.is_fp_tensor(arg):
new_args.append(cast_fn(arg))
else:
new_args.append(arg)
return orig_fn(*new_args)
utils.set_func_save(handle, mod, fn, wrapper)
def disable_casts(mod, fn, handle):
if not utils.has_func(mod, fn):
return
orig_fn = utils.get_func(mod, fn)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
with handle._disable_casts():
return orig_fn(*args, **kwargs)
utils.set_func_save(handle, mod, fn, wrapper)
from .bottleneck import Bottleneck, SpatialBottleneck
from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer
import torch
import torch.distributed as dist
from torch import nn
import nccl_p2p_cuda as inc
import peer_memory_cuda as pm
# Communication-free halo exchanger.
# NB! This halo exchanger does not exchange halos with neighbors as it should; it merely swaps the inputs.
# NB! It is only useful for performance testing.
# NB! Do not use for actual production runs.
class HaloExchanger(object):
def __init__(self, ranks, rank_in_group):
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
self.stream3 = torch.cuda.Stream()
self.group_size = len(ranks)
self.ranks = ranks
self.rank_in_group = rank_in_group
self.wrap_around_left_rank_in_group = (rank_in_group + self.group_size - 1) % self.group_size
self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size
self.left_rank = ranks[rank_in_group-1] if rank_in_group > 0 else -1
self.left_zero = (rank_in_group == 0)
self.right_rank = ranks[rank_in_group+1] if rank_in_group < self.group_size - 1 else -1
self.right_zero = (rank_in_group == self.group_size - 1)
class HaloExchangerNoComm(HaloExchanger):
def __init__(self, ranks, rank_in_group):
super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group)
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
return right_output_halo, left_output_halo
else:
left_input_halo.copy_(right_output_halo)
right_input_halo.copy_(left_output_halo)
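# Hedged sketch of the call contract shared by the exchangers in this file:
# pass the output halos and, optionally, pre-allocated input halos for the
# in-place variant. The tensors here are illustrative placeholders.
def _halo_exchange_example(exchanger, left_out, right_out):
    # Out-of-place: freshly produced input halos are returned.
    left_in, right_in = exchanger.left_right_halo_exchange(left_out, right_out)
    # In-place: results are written into the provided tensors instead.
    exchanger.left_right_halo_exchange(left_out, right_out, left_in, right_in)
    return left_in, right_in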
class HaloExchangerAllGather(HaloExchanger):
def __init__(self, ranks, rank_in_group, comm):
super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group)
# self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks)
self.comm = comm
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
N,Hh,W,C = list(left_output_halo.shape)
send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
send_halos[:,:Hh,:,:].copy_(left_output_halo)
send_halos[:,Hh:,:,:].copy_(right_output_halo)
all_halos = torch.empty((N,2*Hh*self.group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device)
all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.group_size)]
torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True)
ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:,Hh:,:,:]
ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:,:Hh,:,:]
if left_input_halo is None:
if self.left_zero:
ag_left_input_halo.zero_()
if self.right_zero:
ag_right_input_halo.zero_()
return ag_left_input_halo, ag_right_input_halo
else:
if self.left_zero:
left_input_halo.zero_()
else:
left_input_halo.copy_(ag_left_input_halo)
if self.right_zero:
right_input_halo.zero_()
else:
right_input_halo.copy_(ag_right_input_halo)
class HaloExchangerSendRecv(HaloExchanger):
def __init__(self, ranks, rank_in_group):
super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group)
nccl_id = inc.get_unique_nccl_id(1).cuda()
torch.distributed.broadcast(nccl_id, 0)
nccl_id = nccl_id.cpu()
print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id)))
# Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl")
# This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence
# it cannot be accessed from another class.
# TODO: Figure out a way to avoid creating a second global communicator
assert(torch.distributed.get_rank() == self.ranks[self.rank_in_group]), "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % (self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank())
self.handle = inc.init_nccl_comm(nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size())
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
if left_input_halo is None:
left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_rank, self.right_rank , left_output_halo, right_output_halo)
return left_input_halo, right_input_halo
else:
inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo)
class HaloExchangerPeer(HaloExchanger):
def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=0):
super(HaloExchangerPeer, self).__init__(ranks, rank_in_group)
self.diagnostics = False
self.explicit_nhwc = explicit_nhwc
self.numSM = numSM
self.peer_pool = peer_pool
def _allocate_peer_tensor(self, halo):
# Compute size in bytes
# Note: Pad buffer so each CUDA block gets required buffer size
size = 4 * halo.numel() * halo.element_size()
size_per_block = 128 * 2 * 16 # 128 threads each require two 128b buffers
size = (size + size_per_block - 1) // size_per_block * size_per_block
# Construct dtype peer buffer with desired size
shape = [1, 1, 1, size // halo.element_size()]
return self.peer_pool.allocate_peer_tensors(shape, halo.dtype, False, True)
def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None):
inplace = left_input_halo is not None or right_input_halo is not None
if not inplace:
left_input_halo = torch.empty_like(right_output_halo)
right_input_halo = torch.empty_like(left_output_halo)
channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc
left_tx = self._allocate_peer_tensor(left_input_halo)
right_tx = self._allocate_peer_tensor(right_input_halo)
pm.push_pull_halos_1d(
self.diagnostics, self.explicit_nhwc, self.numSM, self.rank_in_group,
self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo,
self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo,
)
if not inplace:
return left_input_halo, right_input_halo
# Class that combines input volume with halos from neighbors (1d).
class HaloPadder:
def __init__(self, halo_ex):
self.halo_ex = halo_ex
self.stream1 = torch.cuda.Stream()
self.stream2 = torch.cuda.Stream()
def __call__(self, y, half_halo, explicit_nhwc, H_split):
channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last)
if explicit_nhwc:
N,H,W,C = list(y.shape)
if H_split:
padded_shape = [N,H+2*half_halo,W,C]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:half_halo,:,:]
ymid = ypad[:,half_halo:H+half_halo,:,:]
yright = ypad[:,H+half_halo:H+2*half_halo,:,:]
oleft = y[:,:half_halo,:,:]
oright = y[:,H-half_halo:,:,:]
else:
padded_shape = [N,H,W+2*half_halo,C]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:W+half_halo,:]
yright = ypad[:,:,W+half_halo:W+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,W-half_halo:,:]
else:
N,C,H,W = list(y.shape)
if H_split:
padded_shape = [N,C,H+2*half_halo,W]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:half_halo,:]
ymid = ypad[:,:,half_halo:H+half_halo,:]
yright = ypad[:,:,H+half_halo:H+2*half_halo,:]
oleft = y[:,:,:half_halo,:]
oright = y[:,:,H-half_halo:,:]
else:
padded_shape = [N,C,H,W+2*half_halo]
ypad = torch.empty(padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last)
yleft = ypad[:,:,:,:half_halo]
ymid = ypad[:,:,:,half_halo:W+half_halo]
yright = ypad[:,:,:,W+half_halo:W+2*half_halo]
oleft = y[:,:,:,:half_halo]
oright = y[:,:,:,W-half_halo:]
with torch.cuda.stream(self.stream1):
self.halo_ex(oleft, oright, yleft, yright)
with torch.cuda.stream(self.stream2):
ymid.copy_(y)
return ypad
def wait(self):
current_stream = torch.cuda.current_stream()
current_stream.wait_stream(self.stream1)
current_stream.wait_stream(self.stream2)
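# Hedged usage sketch of HaloPadder: pad an H-split activation with halos from
# the neighboring ranks, then synchronize before consuming the padded tensor.
# `halo_ex` is any exchanger from halo_exchangers; arguments are illustrative.
def _halo_padder_example(halo_ex, y, half_halo):
    padder = HaloPadder(halo_ex)
    ypad = padder(y, half_halo, explicit_nhwc=False, H_split=True)
    padder.wait()  # current stream waits on the two side streams
    return ypad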
import torch
from bottleneck import Bottleneck
torch.manual_seed(23337)
# Set DEBUG = True to print the layerwise sum of every output in the reference code path.
DEBUG = False
for stride, o_channel in [(1,32), (1,128), (2,32)]:
print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel)
a_ = torch.randn(17,32,28,28)
a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_()
model = Bottleneck(32,8,o_channel,stride=stride).cuda().half().to(memory_format=torch.channels_last)
# test model
b = model(a)
b.mean().backward()
d_grad = a.grad.float()
a.grad = None
torch.cuda.synchronize()
if DEBUG:
print("[DEBUG] ref dx :", d_grad.sum().item())
# print wgrads; we don't need to reset them, since the later cpp path prints before accumulation
for i, w in enumerate(model.w_conv):
print("[DEBUG] ref wgrad{} :".format(i+1), w.grad.sum().item())
wgrads = []
for w in model.w_conv:
wgrads.append(w.grad.float())
model.use_cudnn = True
model.zero_grad()
c = model(a)
c.mean().backward()
torch.cuda.synchronize()
print("comparing native and channels_last:")
print("max error fprop:", (b-c).abs().max().item(), "max elem:", b.abs().max().item())
print("max error dgrad:", (d_grad-a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)):
print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())
nhwc_a = a_.permute(0,2,3,1).contiguous().cuda().half().requires_grad_()
nhwc_model = Bottleneck(32,8,o_channel,stride=stride,explicit_nhwc=True, use_cudnn=True).cuda().half()
for p,q in zip(model.parameters(), nhwc_model.parameters()):
# model's storage is already in nhwc; clone and assign to the explicit-nhwc model
q.data.copy_(p.data.permute(0,2,3,1).contiguous())
for p,q in zip(model.buffers(), nhwc_model.buffers()):
q.data.copy_(p.data)
d = nhwc_model(nhwc_a)
d.mean().backward()
torch.cuda.synchronize()
# permute the channels_last (cudnn) reference outputs and grads to explicit NHWC for comparison
#c_s = c.storage().tolist()
#d_s = d.storage().tolist()
#print(max([x-y for x,y in zip(c_s,d_s)]))
c = c.contiguous(memory_format=torch.contiguous_format).permute(0,2,3,1).contiguous()
d_grad = a.grad.float().permute(0,2,3,1).contiguous()
wgrads = []
for w in model.w_conv:
wgrads.append(w.grad.float().permute(0,2,3,1).contiguous())
torch.cuda.synchronize()
print("comparing nhwc and channels_last:")
print("max error fprop:", (d-c).abs().max().item(), "max elem:", c.abs().max().item())
print("max error dgrad:", (d_grad-nhwc_a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item())
for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)):
print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item())
from .clip_grad import clip_grad_norm_
from typing import Union, Iterable
import torch
_kernel_import_succeeded = False
try:
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
_kernel_import_succeeded = True
except ImportError:
_kernel_import_succeeded = False
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
def clip_grad_norm_(
parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
error_if_nonfinite: bool = False) -> torch.Tensor:
r"""Clips gradient norm of an iterable of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
This is identical to torch.nn.utils.clip_grad_norm_, except it
uses a fused CUDA kernel when computing the 2-norm of GPU tensors
in float32 and float16.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
# Trivial case
if len(parameters) == 0:
return torch.tensor(0.)
# Fallback implementation
if not (_kernel_import_succeeded
and norm_type == 2.0
and any(p.is_cuda for p in parameters)):
return torch.nn.utils.clip_grad_norm_(
parameters,
max_norm,
norm_type=norm_type,
error_if_nonfinite = error_if_nonfinite,
)
# Find fp32 and fp16 gradients on GPU
device = next(p.device for p in parameters if p.is_cuda)
grads_fp32, grads_fp16, grads_misc = [], [], []
for p in parameters:
grad = p.grad.detach()
if p.dtype == torch.float32 and p.device == device:
grads_fp32.append(grad)
elif p.dtype == torch.float16 and p.device == device:
grads_fp16.append(grad)
else:
grads_misc.append(grad)
# Compute gradient L2 norms
norms = []
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
if grads_fp32:
norms.append(
multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_fp32],
False,
)[0]
)
if grads_fp16:
norms.append(
multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_fp16],
False,
)[0],
)
for g in grads_misc:
norms.append(torch.linalg.norm(g).unsqueeze(0).to(device))
total_norm = torch.linalg.norm(torch.cat(norms))
# Check for non-finite values
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f'The total norm of order {norm_type} for gradients from '
'`parameters` is non-finite, so it cannot be clipped. To disable '
'this error and scale the gradients by the non-finite norm anyway, '
'set `error_if_nonfinite=False`')
# Scale gradients
clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
if grads_fp32:
multi_tensor_applier(
amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads_fp32, grads_fp32],
clip_coef_clamped,
)
if grads_fp16:
multi_tensor_applier(
amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads_fp16, grads_fp16],
clip_coef_clamped,
)
for g in grads_misc:
g.mul_(clip_coef_clamped.to(g.device))
return total_norm
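# Hedged usage sketch: this function is a drop-in replacement for
# torch.nn.utils.clip_grad_norm_; `model` and `max_norm` are illustrative.
def _clip_grad_example(model, max_norm=1.0):
    total_norm = clip_grad_norm_(model.parameters(), max_norm)
    return total_norm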
from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU, ConvFrozenScaleBiasReLU