Commit fd03f26a authored by Michael Carilli's avatar Michael Carilli
Browse files

Restoring fused inf/nan check + downscale kernel

parent 48299b0d
......@@ -57,7 +57,7 @@ class AmpHandle(object):
optimizer_step = optimizer.step
def skip_step():
logger = logging.getLogger('apex.amp')
logger.info('Gradient overflow, skipping update')
logger.warning('Gradient overflow, skipping update')
optimizer.step = optimizer_step
optimizer.step = skip_step
......
......@@ -2,8 +2,7 @@ import torch
# from apex_C import scale_check_overflow
# Python stopgap, until we get a future-proof kernel into upstream
def scale_check_overflow(d_grads, scale):
def scale_check_overflow_python(d_grads, scale):
# Exception handling for 18.04 compatibility
try:
cpu_sum = float(d_grads.float().sum())
......@@ -18,28 +17,49 @@ def scale_check_overflow(d_grads, scale):
return False
class LossScaler(object):
warned_no_fused_kernel = False
has_fused_kernel = False
def __init__(self):
self._loss_scale = 2.**16
self._max_loss_scale = 2.**24
self._scale_seq_len = 2000
self._unskipped = 0
self._has_overflow = False
# self._overflow_buf = torch.cuda.ByteTensor(1024,)
try:
import amp_C
LossScaler.has_fused_kernel = True
LossScaler.scale_check_overflow = amp_C.scale_check_overflow
self._overflow_buf = torch.cuda.ByteTensor(1024,)
except ImportError as err:
print("Warning: Amp fused downscale kernel is unavailable, possibly because apex "
"was installed without --cuda_ext. Using Python fallback. ImportError was: ", err)
LossScaler.has_fused_kernel = False
LossScaler.scale_check_overflow = scale_check_overflow_python
def loss_scale(self):
return self._loss_scale
def unscale_and_update(self, param_groups, scale):
# self._overflow_buf.zero_()
if LossScaler.has_fused_kernel:
self._overflow_buf.zero_()
self._has_overflow = False
for p in iter_params(param_groups):
if p.grad is not None:
self._has_overflow = scale_check_overflow(p.grad.data,
1. / scale)
if self._has_overflow:
break
if LossScaler.has_fused_kernel:
LossScaler.scale_check_overflow(p.grad.data,
1. / scale,
self._overflow_buf)
else:
self._has_overflow = LossScaler.scale_check_overflow(p.grad.data,
1. / scale)
if self._has_overflow:
break
# If the fused kernel is available, we only need one D2H memcopy and sync.
if LossScaler.has_fused_kernel:
self._has_overflow = self._overflow_buf.any()
# if self._overflow_buf.any():
if self._has_overflow:
should_skip = True
self._loss_scale /= 2.
......
#include <torch/extension.h>
void scale_check_overflow_cuda(const at::Tensor& d_grads, float scale, const at::Tensor& d_buf);
void scale_check_overflow(at::Tensor grads, float scale, at::Tensor overflow_buf)
{
AT_CHECK(grads.type().is_cuda(), "grads must be a CUDA tensor");
AT_CHECK(grads.is_contiguous(), "grads must be contiguous");
AT_CHECK(overflow_buf.type().is_cuda(), "overflow_buf must be a CUDA tensor");
AT_CHECK(overflow_buf.is_contiguous(), "overflow_buf must be contiguous");
// Make sure we are downscaling the FP32 master grads
AT_CHECK(grads.type().scalarType() == at::ScalarType::Float,
"grads supplied to scale_check_overflow should be fp32 (master grads).")
scale_check_overflow_cuda(grads, scale, overflow_buf);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("scale_check_overflow", &scale_check_overflow, "Fused overflow check + scale for FP32 tensors");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#define BLOCK_SIZE 1024
#define MAX_BLOCKS 1024
// It makes sense to lock the type to "float" here because the downscaling
// should only be applied to the FP32 master gradients. Also, if "in" were
// a different type, it would require divergent code for the vectorized load logic.
// TODO:
// Update overflow check to use reduction from kernel_utils.cuh with
// ReduceOp from THCTensorMathReduce.cuh.
__global__ void scale_reduce_overflow
(float *in,
size_t n,
float scale,
uint8_t *overflow_out)
{
__shared__ uint8_t cta_overflow[BLOCK_SIZE];
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
uint8_t my_overflow = 0;
for (int i = tid * 4; i < n; i+= stride * 4) {
if (i < (n - 3)) {
float4 f4 = ((float4*)in)[i / 4];
if (isfinite(f4.x)) {
f4.x *= scale;
} else {
my_overflow = 1;
}
if (isfinite(f4.y)) {
f4.y *= scale;
} else {
my_overflow = 1;
}
if (isfinite(f4.z)) {
f4.z *= scale;
} else {
my_overflow = 1;
}
if (isfinite(f4.w)) {
f4.w *= scale;
} else {
my_overflow = 1;
}
((float4*)in)[i / 4] = f4;
} else {
for (; i < n; ++i) {
if (isfinite(in[i])) {
in[i] *= scale;
} else {
my_overflow = 1;
}
}
}
}
int tIdx = threadIdx.x;
cta_overflow[tIdx] = my_overflow;
__syncthreads();
int participating = BLOCK_SIZE / 2;
while (participating > 0) {
if (tIdx < participating) {
cta_overflow[tIdx] = max(cta_overflow[tIdx],
cta_overflow[tIdx + participating]);
}
participating /= 2;
__syncthreads();
}
if (tIdx == 0) {
overflow_out[blockIdx.x] = max(cta_overflow[0],
overflow_out[blockIdx.x]);
}
}
void scale_check_overflow_cuda
(const at::Tensor& d_grads,
float scale,
const at::Tensor& d_buf)
{
using namespace at;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t n = d_grads.numel();
size_t buf_n = d_buf.numel();
int num_blks = min((int(n) + BLOCK_SIZE - 1) / BLOCK_SIZE,
MAX_BLOCKS);
assert(buf_n >= num_blks);
scale_reduce_overflow<<<num_blks, BLOCK_SIZE, 0, stream>>>
(d_grads.data<float>(),
n,
scale,
d_buf.data<uint8_t>());
AT_CUDA_CHECK(cudaGetLastError());
}
......@@ -36,6 +36,10 @@ if "--cuda_ext" in sys.argv:
if torch.utils.cpp_extension.CUDA_HOME is None:
print("Warning: nvcc is not available. Ignoring --cuda-ext")
else:
ext_modules.append(
CUDAExtension(name='amp_C',
sources=['csrc/scale_check_overflow.cpp',
'csrc/scale_check_overflow_kernel.cu']))
ext_modules.append(
CUDAExtension(name='fused_adam_cuda',
sources=['apex/optimizers/csrc/fused_adam_cuda.cpp',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment