Commit 80a3f3ca authored by Michael Carilli

Allow multi-tensor unscale to handle FP16 output, so it can also be used for copy-scatter. Rename some options.
parent 4cc1c1b4
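
For reference, the option renames in this diff replace cast_torch_functions with patch_torch_functions and cast_batchnorm with keep_batchnorm_fp32, and drop flatten_model_params / flatten_master_params entirely. The sketch below is an illustrative summary of that mapping, not a helper shipped with apex.

# Illustrative only: summarizes the renames visible in the hunks below.
OPTION_RENAMES = {
    "cast_torch_functions": "patch_torch_functions",
    "cast_batchnorm":       "keep_batchnorm_fp32",
}
REMOVED_OPTIONS = {"flatten_model_params", "flatten_master_params"}

def migrate_amp_kwargs(kwargs):
    """Rewrite old-style option kwargs into the new names (hypothetical helper)."""
    return {OPTION_RENAMES.get(k, k): v
            for k, v in kwargs.items() if k not in REMOVED_OPTIONS}
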
from .amp import init, half_function, float_function, promote_function,\
register_half_function, register_float_function, register_promote_function
from .handle import scale_loss
from .frontend import register
from .frontend import initialize
......@@ -88,7 +88,7 @@ def _initialize(models, optimizers, properties):
# if properties.master_weights:
if properties.cast_model_type:
if properties.cast_batchnorm:
if properties.keep_batchnorm_fp32:
for model in models:
convert_network(model, properties.cast_model_type)
else:
......@@ -120,7 +120,7 @@ def _initialize(models, optimizers, properties):
for optimizer in optimizers:
optimizer.loss_scaler = LossScaler(properties.loss_scale)
if properties.cast_torch_functions:
if properties.patch_torch_functions:
handle = amp_init(loss_scale=properties.loss_scale)
if optimizers_was_list:
......
......@@ -14,12 +14,10 @@ class Properties(object):
"enabled" : False,
"opt_level" : None,
"cast_model_type" : None,
"cast_torch_functions" : False,
"cast_batchnorm" : None,
"patch_torch_functions" : False,
"keep_batchnorm_fp32" : None,
"master_weights" : False,
"loss_scale" : 1.0,
"flatten_model_params" : False,
"flatten_master_params" : False,
"fused_optimizer" : False,
"enable_ddp_interop" : False}
......@@ -69,12 +67,10 @@ class O3:
properties.enabled = True
properties.opt_level = "O3"
properties.cast_model_type = torch.float16
properties.cast_torch_functions = False
properties.cast_batchnorm = False
properties.patch_torch_functions = False
properties.keep_batchnorm_fp32 = False
properties.master_weights = False
properties.loss_scale = 1.0
properties.flatten_model_params = False
properties.flatten_master_params = False
properties.fused_optimizer = False
properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
......@@ -94,12 +90,10 @@ class O2:
properties.enabled = True
properties.opt_level = "O2"
properties.cast_model_type = torch.float16
properties.cast_torch_functions = False
properties.cast_batchnorm = torch.float32
properties.patch_torch_functions = False
properties.keep_batchnorm_fp32 = torch.float32
properties.master_weights = True
properties.loss_scale = "dynamic"
properties.flatten_model_params = False
properties.flatten_master_params = False
properties.fused_optimizer = False
properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
......@@ -118,12 +112,10 @@ class O1:
properties.enabled = True
properties.opt_level = "O1"
properties.cast_model_type = False
properties.cast_torch_functions = True
properties.cast_batchnorm = False
properties.patch_torch_functions = True
properties.keep_batchnorm_fp32 = False
properties.master_weights = False
properties.loss_scale = "dynamic"
properties.flatten_model_params = False
properties.flatten_master_params = False
properties.fused_optimizer = False
properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
......@@ -133,19 +125,16 @@ class O0:
brief = "O0: Pure FP32 training.\n"
more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
"types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
"FP16 arithmetic, although other optimizations like parameter flattening and DDP interop\n"\
"may still be requested.\n"
"FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n"
def __call__(self, properties):
properties.enabled = True
properties.opt_level = "O0"
properties.cast_model_type = torch.float32
properties.cast_torch_functions = False
properties.cast_batchnorm = False
properties.patch_torch_functions = False
properties.keep_batchnorm_fp32 = False
properties.master_weights = False
properties.loss_scale = 1.0
properties.flatten_model_params = False
properties.flatten_master_params = False
properties.fused_optimizer = False
properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
......@@ -178,12 +167,10 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
Expected kwargs:
opt_level=None,
cast_model_type=None,
cast_torch_functions=None,
cast_batchnorm=None,
patch_torch_functions=None,
keep_batchnorm_fp32=None,
master_weights=None,
loss_scale=None,
flatten_model_params=None,
flatten_master_params=None,
enable_ddp_interop=None):
"""
if not enabled:
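
A hedged usage sketch for the renamed keywords, assuming the initialize signature documented above and that it hands back the processed model and optimizer; the exact return convention and accepted values at this commit are assumptions, and the keep_batchnorm_fp32 value mirrors the O2 default shown earlier.

import torch
from apex import amp  # assumes apex with this amp frontend is installed

model = torch.nn.Linear(512, 512).cuda()          # leave the model FP32; amp handles casting
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# keep_batchnorm_fp32 and patch_torch_functions are the new names for
# cast_batchnorm and cast_torch_functions.
model, optimizer = amp.initialize(
    model, optimizer,
    opt_level="O2",
    keep_batchnorm_fp32=torch.float32,
    loss_scale="dynamic")
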
......@@ -218,12 +205,10 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
def check_option_consistency(enabled=True,
opt_level=None,
cast_model_type=None,
cast_torch_functions=None,
cast_batchnorm=None,
patch_torch_functions=None,
keep_batchnorm_fp32=None,
master_weights=None,
loss_scale=None,
flatten_model_params=None,
flatten_master_params=None,
enable_ddp_interop=None):
"""
Utility function that enables users to quickly check if the option combination they intend
......
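
check_option_consistency takes the same keyword set as initialize; a hedged sketch of calling it to vet a combination up front (the import path is an assumption based on the imports at the top of this diff):

from apex.amp.frontend import check_option_consistency  # assumed module path

# Vets an intended option combination without touching any models or optimizers.
check_option_consistency(
    enabled=True,
    opt_level="O1",
    patch_torch_functions=True,   # renamed from cast_torch_functions
    loss_scale="dynamic")
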
......@@ -31,7 +31,7 @@ def scale_loss(loss,
# Needing to drop the cache here as well is an ugly gotcha.
# But for now I think it's necessary to short-circuit.
# Probably ok to skip this if not delay_unscale
if _amp_state.opt_properties.cast_torch_functions:
if _amp_state.opt_properties.patch_torch_functions:
_amp_state.handle._clear_cache()
return
......@@ -60,7 +60,7 @@ def scale_loss(loss,
optimizer.step = skip_step
# Probably ok to skip this if not delay_unscale
if _amp_state.opt_properties.cast_torch_functions:
if _amp_state.opt_properties.patch_torch_functions:
_amp_state.handle._clear_cache()
......
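
The cache clearing above runs inside the scale_loss context manager whenever patch_torch_functions is active. A hedged sketch of the surrounding training-step pattern, with model, optimizer, criterion, and loader as placeholders:

from apex import amp

for inputs, targets in loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    # On exit, scale_loss unscales the grads and, when patch_torch_functions is set,
    # clears the casting cache as shown in the hunks above.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
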
......@@ -24,7 +24,7 @@ def scale_check_overflow_python(model_grad, scale, master_grad):
class LossScaler(object):
warned_no_fused_kernel = False
warned_fp16_grad = False
warned_unscaling_non_fp32_grad = False
has_fused_kernel = False
def __init__(self,
......@@ -49,7 +49,7 @@ class LossScaler(object):
LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
else:
if not LossScaler.warned_no_fused_kernel:
print("Warning: multi_tensor_applier fused downscale kernel is unavailable, "
print("Warning: multi_tensor_applier fused unscale kernel is unavailable, "
"possibly because apex was installed without --cuda_ext --cpp_ext. "
"Using Python fallback. Original ImportError was: ",
multi_tensor_applier.import_err)
......@@ -62,14 +62,14 @@ class LossScaler(object):
def unscale_grads_python(self, model_grads, master_grads, scale):
for model, master in zip(model_grads, master_grads):
if model is not None:
if (master.type() != "torch.cuda.FloatTensor"
and not LossScaler.warned_fp16_grad):
if not LossScaler.warned_unscaling_non_fp32_grad:
if master.type() != "torch.cuda.FloatTensor":
logger = logging.getLogger("apex.amp")
logger.warning(
"Attempting to downscale {} grads. ".format(master.type()) +
"Downscaling non-fp32 grads may indicate an error. "
"Attempting to unscale a grad with type {} ".format(master.type()) +
"Unscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
LossScaler.warned_fp16_grad = True
LossScaler.warned_unscaling_non_fp32_grad = True
self._has_overflow = scale_check_overflow_python(
model,
1./scale,
......@@ -102,15 +102,21 @@ class LossScaler(object):
# The master grads should never be fp16. The kernel can't handle that, so bail out
# and print a warning. This is overly conservative, and maybe we do want to enable
# fast downscaling of fp16 grads eventually.
if any(grad.type() == "torch.cuda.HalfTensor" for grad in master_grads):
self.unscale_grads_python(model_grads, master_grads, scale)
else:
# This is inefficient if opt_level is O1 and loss scale is 1.0. But to elide
# the launch, I would need to make sure the model grads are the master grads.
# The O(N) checks are proliferating...
if not LossScaler.warned_unscaling_non_fp32_grad:
if any(grad.type() != "torch.cuda.FloatTensor" for grad in master_grads):
logger = logging.getLogger("apex.amp")
logger.warning(
"Attempting to unscale grads that are not FP32. "
"Unscaling non-fp32 grads may indicate an error. "
"When using Amp, you don't need to call .half() on your model.")
# Warning: setting this to True unconditionally allows the possibility of an escape
# if never-before-seen non-fp32 grads are created in some later iteration.
LossScaler.warned_unscaling_non_fp32_grad = True
self._overflow_buf.zero_()
# handle case of opt_level O1 and loss_scale 1.0. There's also some
# special-cased yields in scale_loss to potentially short-circuit earlier.
# TODO: Profile and find out if all the O(N) list processing in unscale()
# is a bottleneck.
if scale == 1.0 and all_same and not self.dynamic:
return
else:
......
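
For reference, a minimal self-contained sketch of what the per-tensor Python fallback above appears to do: scale_check_overflow_python is called with the model grad, 1/scale, and the master grad, so presumably it writes model_grad * (1/scale) into master_grad and reports whether anything non-finite showed up. This is an illustrative reimplementation, not apex's code.

import torch

def scale_check_overflow_sketch(model_grad, inv_scale, master_grad):
    """Illustrative: master_grad <- model_grad * inv_scale; True means inf/NaN was seen."""
    scaled = model_grad.float() * inv_scale   # unscale in FP32, matching master precision
    master_grad.copy_(scaled)
    return not torch.isfinite(scaled).all().item()

The fused path adds one more optimization on top of this: when scale == 1.0, the model grads and master grads are the same tensors, and dynamic loss scaling is off, the kernel launch is skipped entirely, since unscaling would just copy each tensor onto itself.
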
......@@ -11,6 +11,8 @@
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
......@@ -18,8 +20,8 @@ template<int n> struct TensorList
{
void* addresses[n][depth_to_max_tensors[n-1]];
int sizes[depth_to_max_tensors[n-1]];
int block_to_tensor[depth_to_max_blocks[n-1]];
int block_to_chunk[depth_to_max_blocks[n-1]];
unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
};
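
Switching block_to_tensor to unsigned char is safe because no depth allows more than 110 tensors per launch, which fits in a byte, while block_to_chunk indexes chunks within a tensor and can exceed 255 (hence the comment above). A hedged back-of-the-envelope check of the struct size against the roughly 4 KB kernel-argument budget, assuming 8-byte pointers, 4-byte ints, and ignoring padding:

# Illustrative size estimate only; the tables mirror the constexpr arrays above.
depth_to_max_tensors = [110, 64, 48, 36, 30]
depth_to_max_blocks = [320, 320, 320, 320, 320]

def tensorlist_bytes(depth, tensor_index_bytes):
    max_t = depth_to_max_tensors[depth - 1]
    max_b = depth_to_max_blocks[depth - 1]
    return (depth * max_t * 8                # void* addresses[depth][max_t]
            + max_t * 4                      # int sizes[max_t]
            + max_b * tensor_index_bytes     # block_to_tensor[max_b]
            + max_b * 4)                     # int block_to_chunk[max_b]

for depth in range(1, 6):
    # Roughly 2.9 KB per launch with a byte-sized index versus ~3.9 KB with int,
    # leaving more headroom under the ~4 KB limit for the remaining kernel args.
    print(depth, tensorlist_bytes(depth, 1), tensorlist_bytes(depth, 4))
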
......@@ -44,7 +46,7 @@ void multi_tensor_apply(
T callable,
ArgTypes... args)
{
AT_CHECK(tensor_lists.size() > 0, "tensor_lists.size() is not > 0");
AT_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
AT_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
......@@ -53,6 +55,7 @@ void multi_tensor_apply(
AT_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for(int t = 0; t < tensor_lists[l].size(); t++)
{
// TODO: Print which tensor fails.
AT_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
AT_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
AT_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
......
......@@ -10,7 +10,7 @@
#define BLOCK_SIZE 512
#define ILP 4
template<typename in_t>
template<typename in_t, typename out_t>
struct ScaleFunctor
{
__device__ __forceinline__ void operator()(
......@@ -34,7 +34,7 @@ struct ScaleFunctor
in_t* in = (in_t*)tl.addresses[0][tensor_loc];
in += chunk_idx*chunk_size;
float* out = (float*)tl.addresses[1][tensor_loc];
out_t* out = (out_t*)tl.addresses[1][tensor_loc];
out += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
......@@ -65,7 +65,7 @@ struct ScaleFunctor
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
if(isfinite(incoming_vals[ii]))
out[i] = incoming_vals[ii]*scale;
out[i] = static_cast<out_t>(incoming_vals[ii]*scale);
else
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
}
......@@ -96,13 +96,30 @@ void multi_tensor_scale_cuda(
[&]
{
// using accscalar_t = acc_type<scalar_t, true>;
switch(tensor_lists[1][0].type().scalarType())
{
case at::ScalarType::Half:
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
ScaleFunctor<scalar_t, at::Half>(),
scale);
break;
case at::ScalarType::Float:
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
ScaleFunctor<scalar_t>(),
ScaleFunctor<scalar_t, float>(),
scale);
break;
default:
AT_ERROR("multi_tensor_scale_cuda not implemented for output type = ",
tensor_lists[1][0].type().toString());
}
});
AT_CUDA_CHECK(cudaGetLastError());
......
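
With the output type templated, the same kernel now unscales FP32 master grads directly into FP16 model grads, which is the copy-scatter use named in the commit message. A hedged sketch of driving it from Python, following the call pattern in the test below; the import paths assume apex was built with --cpp_ext --cuda_ext:

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier  # assumed import path

scale = 4.0
overflow_buf = torch.cuda.IntTensor([0])

# FP32 "master" tensors in, FP16 "model" tensors out: unscale and downcast in one pass.
fp32_in = [torch.cuda.FloatTensor(4096).fill_(scale) for _ in range(3)]
fp16_out = [torch.cuda.HalfTensor(4096) for _ in range(3)]

multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf,
                     [fp32_in, fp16_out], 1. / scale)

# overflow_buf is flipped to nonzero if any input value was inf/NaN.
assert overflow_buf.item() == 0
assert all((out.float() == 1.0).all() for out in fp16_out)
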
......@@ -34,14 +34,14 @@ class TestMultiTensorScale(unittest.TestCase):
pass
# The tensor creation here is written for convenience, not speed.
def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, inplace=False):
def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, inplace=False):
self.overflow_buf.zero_()
a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)
out_list = []
for i in range(repeat_tensors):
out_list += [a.clone(), b.clone()]
out_list += [a.clone().to(out_type), b.clone().to(out_type)]
if inplace:
in_list = out_list
......@@ -50,17 +50,17 @@ class TestMultiTensorScale(unittest.TestCase):
applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
self.assertTrue(all([torch.allclose(out, self.ref) for out in out_list]))
self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]))
self.assertTrue(self.overflow_buf.item() == 0)
def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, t, ind, val, inplace=False):
def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
self.overflow_buf.zero_()
a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)
out_list = []
for i in range(repeat_tensors):
out_list += [a.clone(), b.clone()]
out_list += [a.clone().to(out_type), b.clone().to(out_type)]
if inplace:
in_list = out_list
......@@ -103,21 +103,22 @@ class TestMultiTensorScale(unittest.TestCase):
repeat_tensors = (
1,
55)
dtype_inplace_pairs = (
(torch.float16, False),
(torch.float32, False),
(torch.float32, True))
for sizea, sizeb in input_size_pairs:
for applier in appliers:
for repeat in repeat_tensors:
for dtype, inplace in dtype_inplace_pairs:
self.downscale(sizea, sizeb, applier, repeat, dtype, inplace=inplace)
self.find_inf(sizea, sizeb, applier, repeat, dtype,
for in_type in (torch.float32, torch.float16):
for out_type in (torch.float32, torch.float16):
for inplace in (True, False):
if inplace is True and (out_type is not in_type):
continue
else:
self.downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace)
self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
0, 0, float('nan'), inplace=inplace)
self.find_inf(sizea, sizeb, applier, repeat, dtype,
self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
self.find_inf(sizea, sizeb, applier, repeat, dtype,
self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
2*(repeat//2), sizea//2, float('inf'), inplace=inplace)
......