Commit 8f53411a authored by Michael Carilli

Removing deprecated scale_check_overflow kernel

parent 62ce27d2
@@ -3,8 +3,6 @@ from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import _amp_state, master_params, maybe_print
 from itertools import product
-# from apex_C import scale_check_overflow
 def scale_check_overflow_python(model_grad, scale, master_grad, check_overflow=False):
     # Exception handling for 18.04 compatibility
     if check_overflow:
......
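For context, the pure-Python fallback whose signature survives above is untouched by this commit; its body is collapsed in this hunk, so the following is only a hedged sketch of its behavior (optionally detect inf/nan, then copy-and-scale into the master grad), not the verbatim source:

def scale_check_overflow_python_sketch(model_grad, scale, master_grad, check_overflow=False):
    # Sketch only: optionally check the incoming grad for inf/nan, then write
    # scale * model_grad into master_grad. Returns True if an overflow was found.
    if check_overflow:
        cpu_sum = float(model_grad.float().sum())
        # A sum containing inf stays inf; a sum containing nan is nan (and nan != nan).
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
    if master_grad is not model_grad:
        master_grad.copy_(model_grad)
    if scale != 1.0:
        master_grad.mul_(scale)
    return False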
@@ -6,35 +6,7 @@ void multi_tensor_scale_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   float scale);
-void scale_check_overflow_cuda(
-  const at::Tensor& grads,
-  float scale,
-  const at::Tensor& d_buf,
-  const at::Tensor& downscaled_grads);
-void scale_check_overflow(
-  at::Tensor grads,
-  float scale,
-  at::Tensor overflow_buf,
-  at::Tensor downscaled_grads)
-  // const at::optional<at::Tensor> downscaled_grads)
-{
-  AT_CHECK(grads.type().is_cuda(), "grads must be a CUDA tensor");
-  AT_CHECK(grads.is_contiguous(), "grads must be contiguous");
-  AT_CHECK(overflow_buf.type().is_cuda(), "overflow_buf must be a CUDA tensor");
-  AT_CHECK(overflow_buf.is_contiguous(), "overflow_buf must be contiguous");
-  AT_CHECK(downscaled_grads.type().is_cuda(), "downscaled_grads must be a CUDA tensor");
-  AT_CHECK(downscaled_grads.is_contiguous(), "downscaled_grads must be contiguous");
-  // Make sure we are downscaling the FP32 master grads
-  AT_CHECK(downscaled_grads.type().scalarType() == at::ScalarType::Float,
-           "The output grads supplied to scale_check_overflow should be fp32 (master grads).")
-  AT_CHECK(grads.numel() == downscaled_grads.numel(), "Input and output grads must be the same size.");
-  scale_check_overflow_cuda(grads, scale, overflow_buf, downscaled_grads);
-}
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("scale_check_overflow", &scale_check_overflow, "Fused overflow check + scale for FP32 tensors");
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors");
 }
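After this hunk, the only fused scaling entry point left in amp_C is multi_tensor_scale, which operates on lists of contiguous tensors and reports overflow through a flag tensor. As a rough illustration of how it is invoked from Python through the multi_tensor_applier imported in the first hunk (the gradient lists and the loss scale of 128 here are made up for the example):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# Illustrative gradient lists: downscale fp16 model grads into fp32 master grads.
model_grads  = [torch.randn(1024, device='cuda', dtype=torch.float16)]
master_grads = [torch.empty(1024, device='cuda', dtype=torch.float32)]

# Nonzero after the call if any input element was inf/nan.
overflow_buf = torch.cuda.IntTensor([0])

# multi_tensor_applier chunks the tensor lists and launches the fused kernel per chunk.
multi_tensor_applier(amp_C.multi_tensor_scale,
                     overflow_buf,
                     [model_grads, master_grads],
                     1./128.)

had_overflow = overflow_buf.item() != 0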
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 256
#define NBLOCKS 160*4
#define ILP 4

// It makes sense to lock the output type to fp32 because the downscaled
// grads should be master grads (and in the case of Amp, the params and their
// gradients should always be fp32).
template<typename in_t>
__global__ void scale_reduce_overflow(in_t* in,
                                      float* out,
                                      int n,
                                      float scale,
                                      volatile int* overflow_global)
{
  __shared__ int overflow;
  float incoming_vals[4];

  // Non-divergent exit condition for the __syncthreads
  for(int chunk_start = blockIdx.x*blockDim.x*ILP;
      chunk_start < n;
      chunk_start += gridDim.x*blockDim.x*ILP)
  {
    if(threadIdx.x == 0)
      overflow = *overflow_global;

    __syncthreads();

    if(overflow == 1)
      break;

    #pragma unroll
    for(int ii = 0; ii < ILP; ii++)
    {
      incoming_vals[ii] = 0;
      int i = chunk_start + threadIdx.x + ii*blockDim.x;
      if(i < n)
        incoming_vals[ii] = static_cast<float>(in[i]);
    }

    #pragma unroll
    for(int ii = 0; ii < ILP; ii++)
    {
      int i = chunk_start + threadIdx.x + ii*blockDim.x;
      if(i < n)
        if(isfinite(incoming_vals[ii]))
          out[i] = incoming_vals[ii]*scale;
        else
          *overflow_global = 1; // Blindly fire off a write. These will race but that's ok.
    } // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
  }   // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
}     // It's possible we can just lean on the cache (no smem or syncs) and still be fast.

void scale_check_overflow_cuda
  (const at::Tensor& grads,
   float scale,
   const at::Tensor& overflow_buf,
   const at::Tensor& downscaled_grads)
{
  using namespace at;

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int n = grads.numel();

  // Lock the output (downscaled) type to float.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grads.type(),
    "scale_check_overflow_cuda",
    [&]
    {
      // using accscalar_t = acc_type<scalar_t, true>;
      scale_reduce_overflow<<<NBLOCKS, BLOCK_SIZE, 0, stream>>>
        (grads.data<scalar_t>(),
         downscaled_grads.data<float>(),
         n,
         scale,
         overflow_buf.data<int>());
    });

  AT_CUDA_CHECK(cudaGetLastError());
}
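The deleted csrc/scale_check_overflow_kernel.cu above downscaled a single contiguous grads tensor into an fp32 output while flagging inf/nan. Ignoring the kernel's early exit once the flag is set, its observable effect corresponds roughly to this eager-mode sketch; the function name is invented for illustration:

import torch

def scale_check_overflow_reference(grads, scale, overflow_buf, downscaled_grads):
    # Write scale * grads into the fp32 output wherever values are finite,
    # and set overflow_buf to 1 if any element is inf/nan. Non-finite
    # positions are left untouched, matching the kernel's per-element branch.
    finite = torch.isfinite(grads)
    if not bool(finite.all()):
        overflow_buf.fill_(1)
    downscaled_grads.copy_(torch.where(finite, grads.float() * scale, downscaled_grads))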
@@ -48,7 +48,6 @@ if "--cuda_ext" in sys.argv:
     ext_modules.append(
         CUDAExtension(name='amp_C',
                       sources=['csrc/amp_C_frontend.cpp',
-                              'csrc/scale_check_overflow_kernel.cu',
                               'csrc/multi_tensor_scale_kernel.cu'],
                       extra_compile_args={'cxx': ['-O3'],
                                           'nvcc':['-lineinfo',
......
import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    scale_check_overflow = amp_C.scale_check_overflow
    disabled = False
except ImportError as err:
    print("amp_C fused kernel unavailable, disabling TestScale. ImportError was ", err)
    disabled = True

class TestScale(unittest.TestCase):

    def setUp(self):
        self.scale = 128.0
        self.nx = 999
        self.ny = 888
        self.overflow_buf = torch.cuda.IntTensor([0])

        self.fp16 = torch.ones((self.ny, self.nx), device='cuda', dtype=torch.float16)
        self.fp32 = torch.ones((self.ny, self.nx), device='cuda', dtype=torch.float32)

        self.fp16_ref = torch.ones((1, 1), device='cuda', dtype=torch.float16)
        self.fp32_ref = torch.ones((1, 1), device='cuda', dtype=torch.float32)

        common_init(self)

    def tearDown(self):
        pass

    def downscale_test(self, input, output, ref):
        self.overflow_buf.zero_()
        input.fill_(1.0)
        if input is not output:
            output.fill_(3.0)
        input.mul_(self.scale)
        scale_check_overflow(input, 1./self.scale, self.overflow_buf, output)
        self.assertTrue(torch.allclose(output, ref))
        self.assertTrue(self.overflow_buf.item() == 0)

    def find_inf_test(self, input, output, ref, x, y, val):
        self.overflow_buf.zero_()
        input.fill_(1.0)
        if input is not output:
            output.fill_(3.0)
        input[x,y] = val
        scale_check_overflow(input, 1./self.scale, self.overflow_buf, output)
        self.assertTrue(self.overflow_buf.item())

    # Currently, the fused kernel gives a hard error if you attempt to downscale
    # into fp16 output, which imo is the desired behavior. Maybe someday we
    # will learn otherwise.
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp16_to_fp16(self):
    #     self.downscale_test(self.fp16, self.fp16, self.fp16_ref)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fp16_to_fp32(self):
        self.downscale_test(self.fp16, self.fp32, self.fp32_ref)

    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp32_to_fp16(self):
    #     self.downscale_test(self.fp32, self.fp16, self.fp16_ref)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fp32_to_fp32(self):
        self.downscale_test(self.fp32, self.fp32, self.fp32_ref)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fp16_to_fp32_find_inf_nan(self):
        self.find_inf_test(self.fp16, self.fp32, self.fp32_ref, 0, 0, float('nan'))
        self.find_inf_test(self.fp16, self.fp32, self.fp32_ref, self.ny//2, self.nx//2, float('inf'))
        self.find_inf_test(self.fp16, self.fp32, self.fp32_ref, self.ny-1, self.nx-1, float('nan'))

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fp32_to_fp32_find_inf_nan(self):
        self.find_inf_test(self.fp32, self.fp32, self.fp32_ref, 0, 0, float('inf'))
        self.find_inf_test(self.fp32, self.fp32, self.fp32_ref, self.ny//2, self.nx//2, float('nan'))
        self.find_inf_test(self.fp32, self.fp32, self.fp32_ref, self.ny-1, self.nx-1, float('inf'))

if __name__ == '__main__':
    unittest.main()