Commit 80a3f3ca authored by Michael Carilli


Allow multi-tensor unscale to handle FP16 output, so it can also be used for copy-scatter. Rename some options.
parent 4cc1c1b4
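
Before the per-file changes, a quick sketch of how the renamed options are meant to surface to users through the new frontend. This is a minimal illustration, not code from the commit: the toy model and optimizer are placeholders, and the return-value convention follows the amp API as it eventually shipped, which this work-in-progress commit may not match exactly.

import torch
from apex import amp  # 'initialize' is now re-exported from apex.amp (see the __init__.py change below)

# Toy model and optimizer, purely for illustration.
model = torch.nn.Linear(128, 128).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# opt_level selects a Properties preset (O0-O3). The renamed knobs
# patch_torch_functions and keep_batchnorm_fp32 (formerly cast_torch_functions
# and cast_batchnorm) can be passed as keyword overrides, per the
# "Expected kwargs" list in frontend.py below.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale="dynamic")

data = torch.randn(32, 128, device="cuda")
loss = model(data).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
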
from .amp import init, half_function, float_function, promote_function,\
    register_half_function, register_float_function, register_promote_function
from .handle import scale_loss
-from .frontend import register
+from .frontend import initialize
@@ -88,7 +88,7 @@ def _initialize(models, optimizers, properties):
    # if properties.master_weights:
    if properties.cast_model_type:
-       if properties.cast_batchnorm:
+       if properties.keep_batchnorm_fp32:
            for model in models:
                convert_network(model, properties.cast_model_type)
        else:
@@ -120,7 +120,7 @@ def _initialize(models, optimizers, properties):
    for optimizer in optimizers:
        optimizer.loss_scaler = LossScaler(properties.loss_scale)
-   if properties.cast_torch_functions:
+   if properties.patch_torch_functions:
        handle = amp_init(loss_scale=properties.loss_scale)
    if optimizers_was_list:
...
@@ -14,12 +14,10 @@ class Properties(object):
            "enabled" : False,
            "opt_level" : None,
            "cast_model_type" : None,
-           "cast_torch_functions" : False,
-           "cast_batchnorm" : None,
+           "patch_torch_functions" : False,
+           "keep_batchnorm_fp32" : None,
            "master_weights" : False,
            "loss_scale" : 1.0,
-           "flatten_model_params" : False,
-           "flatten_master_params" : False,
            "fused_optimizer" : False,
            "enable_ddp_interop" : False}
@@ -69,12 +67,10 @@ class O3:
        properties.enabled = True
        properties.opt_level = "O3"
        properties.cast_model_type = torch.float16
-       properties.cast_torch_functions = False
-       properties.cast_batchnorm = False
+       properties.patch_torch_functions = False
+       properties.keep_batchnorm_fp32 = False
        properties.master_weights = False
        properties.loss_scale = 1.0
-       properties.flatten_model_params = False
-       properties.flatten_master_params = False
        properties.fused_optimizer = False
        properties.enable_ddp_interop = False
        return properties # modified in place so this isn't really necessary
@@ -94,12 +90,10 @@ class O2:
        properties.enabled = True
        properties.opt_level = "O2"
        properties.cast_model_type = torch.float16
-       properties.cast_torch_functions = False
-       properties.cast_batchnorm = torch.float32
+       properties.patch_torch_functions = False
+       properties.keep_batchnorm_fp32 = torch.float32
        properties.master_weights = True
        properties.loss_scale = "dynamic"
-       properties.flatten_model_params = False
-       properties.flatten_master_params = False
        properties.fused_optimizer = False
        properties.enable_ddp_interop = False
        return properties # modified in place so this isn't really necessary
@@ -118,12 +112,10 @@ class O1:
        properties.enabled = True
        properties.opt_level = "O1"
        properties.cast_model_type = False
-       properties.cast_torch_functions = True
-       properties.cast_batchnorm = False
+       properties.patch_torch_functions = True
+       properties.keep_batchnorm_fp32 = False
        properties.master_weights = False
        properties.loss_scale = "dynamic"
-       properties.flatten_model_params = False
-       properties.flatten_master_params = False
        properties.fused_optimizer = False
        properties.enable_ddp_interop = False
        return properties # modified in place so this isn't really necessary
@@ -133,19 +125,16 @@ class O0:
    brief = "O0: Pure FP32 training.\n"
    more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
           "types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
-          "FP16 arithmetic, although other optimizations like parameter flattening and DDP interop\n"\
-          "may still be requested.\n"
+          "FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n"

    def __call__(self, properties):
        properties.enabled = True
        properties.opt_level = "O0"
        properties.cast_model_type = torch.float32
-       properties.cast_torch_functions = False
-       properties.cast_batchnorm = False
+       properties.patch_torch_functions = False
+       properties.keep_batchnorm_fp32 = False
        properties.master_weights = False
        properties.loss_scale = 1.0
-       properties.flatten_model_params = False
-       properties.flatten_master_params = False
        properties.fused_optimizer = False
        properties.enable_ddp_interop = False
        return properties # modified in place so this isn't really necessary
@@ -178,12 +167,10 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
    Expected kwargs:
        opt_level=None,
        cast_model_type=None,
-       cast_torch_functions=None,
-       cast_batchnorm=None,
+       patch_torch_functions=None,
+       keep_batchnorm_fp32=None,
        master_weights=None,
        loss_scale=None,
-       flatten_model_params=None,
-       flatten_master_params=None,
        enable_ddp_interop=None):
    """
    if not enabled:
@@ -218,12 +205,10 @@ def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
def check_option_consistency(enabled=True,
                             opt_level=None,
                             cast_model_type=None,
-                            cast_torch_functions=None,
-                            cast_batchnorm=None,
+                            patch_torch_functions=None,
+                            keep_batchnorm_fp32=None,
                             master_weights=None,
                             loss_scale=None,
-                            flatten_model_params=None,
-                            flatten_master_params=None,
                             enable_ddp_interop=None):
    """
    Utility function that enables users to quickly check if the option combination they intend
...
@@ -31,7 +31,7 @@ def scale_loss(loss,
        # Needing to drop the cache here as well is an ugly gotcha.
        # But for now I think it's necessary to short-circuit.
        # Probably ok to skip this if not delay_unscale
-       if _amp_state.opt_properties.cast_torch_functions:
+       if _amp_state.opt_properties.patch_torch_functions:
            _amp_state.handle._clear_cache()
        return
@@ -60,7 +60,7 @@ def scale_loss(loss,
        optimizer.step = skip_step
    # Probably ok to skip this if not delay_unscale
-   if _amp_state.opt_properties.cast_torch_functions:
+   if _amp_state.opt_properties.patch_torch_functions:
        _amp_state.handle._clear_cache()
...
@@ -24,7 +24,7 @@ def scale_check_overflow_python(model_grad, scale, master_grad):
class LossScaler(object):
    warned_no_fused_kernel = False
-   warned_fp16_grad = False
+   warned_unscaling_non_fp32_grad = False
    has_fused_kernel = False

    def __init__(self,
@@ -49,7 +49,7 @@ class LossScaler(object):
            LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
        else:
            if not LossScaler.warned_no_fused_kernel:
-               print("Warning: multi_tensor_applier fused downscale kernel is unavailable, "
+               print("Warning: multi_tensor_applier fused unscale kernel is unavailable, "
                      "possibly because apex was installed without --cuda_ext --cpp_ext. "
                      "Using Python fallback. Original ImportError was: ",
                      multi_tensor_applier.import_err)
@@ -62,14 +62,14 @@ class LossScaler(object):
    def unscale_grads_python(self, model_grads, master_grads, scale):
        for model, master in zip(model_grads, master_grads):
            if model is not None:
-               if (master.type() != "torch.cuda.FloatTensor"
-                       and not LossScaler.warned_fp16_grad):
-                   logger = logging.getLogger("apex.amp")
-                   logger.warning(
-                       "Attempting to downscale {} grads. ".format(master.type()) +
-                       "Downscaling non-fp32 grads may indicate an error. "
-                       "When using Amp, you don't need to call .half() on your model.")
-                   LossScaler.warned_fp16_grad = True
+               if not LossScaler.warned_unscaling_non_fp32_grad:
+                   if master.type() != "torch.cuda.FloatTensor":
+                       logger = logging.getLogger("apex.amp")
+                       logger.warning(
+                           "Attempting to unscale a grad with type {} ".format(master.type()) +
+                           "Unscaling non-fp32 grads may indicate an error. "
+                           "When using Amp, you don't need to call .half() on your model.")
+                       LossScaler.warned_unscaling_non_fp32_grad = True
                self._has_overflow = scale_check_overflow_python(
                    model,
                    1./scale,
@@ -102,15 +102,21 @@ class LossScaler(object):
        # The master grads should never be fp16. The kernel can't handle that, so bail out
        # and print a warning. This is overly conservative, and maybe we do want to enable
        # fast downscaling of fp16 grads eventually.
-       if any(grad.type() == "torch.cuda.HalfTensor" for grad in master_grads):
-           self.unscale_grads_python(model_grads, master_grads, scale)
-       else:
-           # This is inefficient if opt_level is O1 and loss scale is 1.0. But to elide
-           # the launch, I would need to make sure the model grads are the master grads.
-           # The O(N) checks are proliferating...
+       if not LossScaler.warned_unscaling_non_fp32_grad:
+           if any(grad.type() != "torch.cuda.FloatTensor" for grad in master_grads):
+               logger = logging.getLogger("apex.amp")
+               logger.warning(
+                   "Attempting to unscale grads that are not FP32. "
+                   "Unscaling non-fp32 grads may indicate an error. "
+                   "When using Amp, you don't need to call .half() on your model.")
+               # Warning: setting this to True unconditionally allows the possibility of an escape
+               # if never-before-seen non-fp32 grads are created in some later iteration.
+               LossScaler.warned_unscaling_non_fp32_grad = True
        self._overflow_buf.zero_()
        # handle case of opt_level O1 and loss_scale 1.0. There's also some
        # special-cased yields in scale_loss to potentially short-circuit earlier.
+       # TODO: Profile and find out if all the O(N) list processing in unscale()
+       # is a bottleneck.
        if scale == 1.0 and all_same and not self.dynamic:
            return
        else:
...
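
For orientation, the fused path above boils down to a single multi_tensor_applier call over the model-grad and master-grad lists. A minimal standalone sketch of that call pattern, mirroring the usage in LossScaler and in the unit test further down (the tensor shapes and the 65536.0 scale are illustrative only, and amp_C is only available when apex is built with --cpp_ext --cuda_ext):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

scale = 65536.0  # illustrative loss scale
overflow_buf = torch.cuda.IntTensor([0])

# Stand-ins for .grad tensors produced by backward() on an FP16 model
# that keeps FP32 master weights.
model_grads = [torch.randn(4096, device="cuda", dtype=torch.float16) for _ in range(3)]
master_grads = [torch.empty_like(g, dtype=torch.float32) for g in model_grads]

# Multiply every model grad by 1/scale and scatter the results into the master
# grads in one fused pass; overflow_buf becomes nonzero if any value is inf/nan.
multi_tensor_applier(amp_C.multi_tensor_scale,
                     overflow_buf,
                     [model_grads, master_grads],
                     1./scale)

had_overflow = overflow_buf.item() != 0
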
@@ -11,6 +11,8 @@
// This header is the one-stop shop for all your multi-tensor apply needs.

+// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
+
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
@@ -18,8 +20,8 @@ template<int n> struct TensorList
{
    void* addresses[n][depth_to_max_tensors[n-1]];
    int sizes[depth_to_max_tensors[n-1]];
-   int block_to_tensor[depth_to_max_blocks[n-1]];
-   int block_to_chunk[depth_to_max_blocks[n-1]];
+   unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
+   int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
};
@@ -44,7 +46,7 @@ void multi_tensor_apply(
    T callable,
    ArgTypes... args)
{
-   AT_CHECK(tensor_lists.size() > 0, "tensor_lists.size() is not > 0");
+   AT_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
    int len0 = tensor_lists[0].size();
    AT_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
@@ -53,6 +55,7 @@ void multi_tensor_apply(
        AT_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
        for(int t = 0; t < tensor_lists[l].size(); t++)
        {
+           // TODO: Print which tensor fails.
            AT_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
            AT_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
            AT_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
...
@@ -10,7 +10,7 @@
#define BLOCK_SIZE 512
#define ILP 4

-template<typename in_t>
+template<typename in_t, typename out_t>
struct ScaleFunctor
{
    __device__ __forceinline__ void operator()(
@@ -34,7 +34,7 @@ struct ScaleFunctor
        in_t* in = (in_t*)tl.addresses[0][tensor_loc];
        in += chunk_idx*chunk_size;

-       float* out = (float*)tl.addresses[1][tensor_loc];
+       out_t* out = (out_t*)tl.addresses[1][tensor_loc];
        out += chunk_idx*chunk_size;

        n -= chunk_idx*chunk_size;
@@ -65,7 +65,7 @@ struct ScaleFunctor
                int i = i_start + threadIdx.x + ii*blockDim.x;
                if(i < n && i < chunk_size)
                    if(isfinite(incoming_vals[ii]))
-                       out[i] = incoming_vals[ii]*scale;
+                       out[i] = static_cast<out_t>(incoming_vals[ii]*scale);
                    else
                        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
@@ -96,13 +96,30 @@ void multi_tensor_scale_cuda(
    [&]
    {
        // using accscalar_t = acc_type<scalar_t, true>;
+       switch(tensor_lists[1][0].type().scalarType())
+       {
+           case at::ScalarType::Half:
+               multi_tensor_apply<2>(
+                   BLOCK_SIZE,
+                   chunk_size,
+                   noop_flag,
+                   tensor_lists,
+                   ScaleFunctor<scalar_t, at::Half>(),
+                   scale);
+               break;
+           case at::ScalarType::Float:
                multi_tensor_apply<2>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
-                   ScaleFunctor<scalar_t>(),
+                   ScaleFunctor<scalar_t, float>(),
                    scale);
+               break;
+           default:
+               AT_ERROR("multi_tensor_scale_cuda not implemented for output type = ",
+                        tensor_lists[1][0].type().toString());
+       }
    });

    AT_CUDA_CHECK(cudaGetLastError());
...
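
The new at::ScalarType::Half output case is what enables the copy-scatter mentioned in the commit message: updated FP32 master parameters can be broadcast back into their FP16 model copies with one fused launch over the whole list instead of a per-tensor copy_(). A hedged sketch of that use, with the scale fixed to 1.0 (the parameter lists here are illustrative, not part of the commit):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor([0])

# FP32 master params (e.g. held by the optimizer) and the FP16 copies the model runs with.
master_params = [torch.randn(4096, device="cuda", dtype=torch.float32) for _ in range(3)]
model_params = [p.clone().half() for p in master_params]

# With scale=1.0 this is a pure fused copy: FP32 in, FP16 out, exercising the
# new Half output path of the kernel above.
multi_tensor_applier(amp_C.multi_tensor_scale,
                     overflow_buf,
                     [master_params, model_params],
                     1.0)
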
@@ -34,14 +34,14 @@ class TestMultiTensorScale(unittest.TestCase):
        pass

    # The tensor creation here is written for convenience, not speed.
-   def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, inplace=False):
+   def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, inplace=False):
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
-           out_list += [a.clone(), b.clone()]
+           out_list += [a.clone().to(out_type), b.clone().to(out_type)]

        if inplace:
            in_list = out_list
@@ -50,17 +50,17 @@ class TestMultiTensorScale(unittest.TestCase):
        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

-       self.assertTrue(all([torch.allclose(out, self.ref) for out in out_list]))
+       self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]))
        self.assertTrue(self.overflow_buf.item() == 0)

-   def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, t, ind, val, inplace=False):
+   def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
-           out_list += [a.clone(), b.clone()]
+           out_list += [a.clone().to(out_type), b.clone().to(out_type)]

        if inplace:
            in_list = out_list
@@ -103,21 +103,22 @@ class TestMultiTensorScale(unittest.TestCase):
        repeat_tensors = (
            1,
            55)
-       dtype_inplace_pairs = (
-           (torch.float16, False),
-           (torch.float32, False),
-           (torch.float32, True))

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
-                   for dtype, inplace in dtype_inplace_pairs:
-                       self.downscale(sizea, sizeb, applier, repeat, dtype, inplace=inplace)
-                       self.find_inf(sizea, sizeb, applier, repeat, dtype,
-                                     0, 0, float('nan'), inplace=inplace)
-                       self.find_inf(sizea, sizeb, applier, repeat, dtype,
-                                     2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
-                       self.find_inf(sizea, sizeb, applier, repeat, dtype,
-                                     2*(repeat//2), sizea//2, float('inf'), inplace=inplace)
+                   for in_type in (torch.float32, torch.float16):
+                       for out_type in (torch.float32, torch.float16):
+                           for inplace in (True, False):
+                               if inplace is True and (out_type is not in_type):
+                                   continue
+                               else:
+                                   self.downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace)
+                                   self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
+                                                 0, 0, float('nan'), inplace=inplace)
+                                   self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
+                                                 2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
+                                   self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
+                                                 2*(repeat//2), sizea//2, float('inf'), inplace=inplace)
...