Commit 848c777d authored by Michael Carilli's avatar Michael Carilli
Browse files

FusedSGD tests passing for all opt_levels

parent c142714b
...@@ -91,6 +91,10 @@ def lazy_init_with_master_weights(self): ...@@ -91,6 +91,10 @@ def lazy_init_with_master_weights(self):
def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None): def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
if scale_override is not None:
grads_have_scale, stashed_have_scale, out_scale = scale_override
# This is a lot of python overhead... # This is a lot of python overhead...
grads_needing_unscale = [] grads_needing_unscale = []
grads_needing_unscale_with_stash = [] grads_needing_unscale_with_stash = []
...@@ -106,20 +110,21 @@ def post_backward_models_are_masters(scaler, params, stashed_grads, scale_overri ...@@ -106,20 +110,21 @@ def post_backward_models_are_masters(scaler, params, stashed_grads, scale_overri
else: # param.grad is None and stashed_grad is None else: # param.grad is None and stashed_grad is None
continue continue
# unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale.
if len(grads_needing_unscale) > 0: if len(grads_needing_unscale) > 0:
scaler.unscale( scaler.unscale(
grads_needing_unscale, grads_needing_unscale,
grads_needing_unscale, grads_needing_unscale,
scaler.loss_scale(), None, # unused_scale, currently present to avoid API breakage elsewhere
models_are_masters=True, models_are_masters=True,
scale_override=scale_override) scale_override=grads_have_scale/out_scale)
if len(grads_needing_unscale_with_stash) > 0: if len(grads_needing_unscale_with_stash) > 0:
scaler.unscale_with_stashed( scaler.unscale_with_stashed(
grads_needing_unscale_with_stash, grads_needing_unscale_with_stash,
stashed, stashed,
grads_needing_unscale_with_stash, grads_needing_unscale_with_stash,
scale_override=scale_override) scale_override=(grads_have_scale, stashed_have_scale, out_scale))
# Clear the stash. # Clear the stash.
for i in range(len(stashed_grads)): for i in range(len(stashed_grads)):
...@@ -323,27 +328,25 @@ def post_backward_with_master_weights_FusedSGD(self, scaler): ...@@ -323,27 +328,25 @@ def post_backward_with_master_weights_FusedSGD(self, scaler):
if self.materialize_master_grads: if self.materialize_master_grads:
post_backward_with_master_weights(self, scaler) post_backward_with_master_weights(self, scaler)
else: else:
# TODO: handle gradient clipping and removal of any lingering scale here.
stash = self._amp_stash stash = self._amp_stash
self._amp_lazy_init() self._amp_lazy_init()
current_scale = scaler.loss_scale() grads_have_scale = scaler.loss_scale()
out_scale = current_scale stashed_have_scale = self.most_recent_scale
out_scale = grads_have_scale
if self.scale_set_by_backward: if self.scale_set_by_backward:
out_scale = min(current_scale, self.most_recent_scale) out_scale = min(grads_have_scale, self.most_recent_scale)
scale_adjustment = out_scale/current_scale
split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash), split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
(stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash)) (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
# Grads created by this backward pass have been scaled by current_scale.
# unscale() implements grads*1/scale, so "scale" should be current_scale/out_scale
# unscale_with_stashed() implements grads*1/scale + stashed_grads*1. # unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
# stashed_grads are scaled by self.most_recent_scale. # stashed_grads are scaled by self.most_recent_scale.
for params, stashed_grads in split_types: for params, stashed_grads in split_types:
post_backward_models_are_masters(scaler, params, stashed_grads) post_backward_models_are_masters(scaler, params, stashed_grads,
(grads_have_scale, stashed_have_scale, out_scale))
self.most_recent_scale = out_scale self.most_recent_scale = out_scale
self.scale_set_by_backward = True self.scale_set_by_backward = True
......
...@@ -132,10 +132,15 @@ def scale_loss(loss, ...@@ -132,10 +132,15 @@ def scale_loss(loss,
maybe_print(("Gradient overflow. Skipping step, loss scaler " + maybe_print(("Gradient overflow. Skipping step, loss scaler " +
"{} reducing loss scale to {}").format(loss_id, "{} reducing loss scale to {}").format(loss_id,
loss_scaler.loss_scale())) loss_scaler.loss_scale()))
# TODO: I don't like the special casing for different optimizer implementations.
# Maybe skip should delegate to a method owned by the optimizers themselves.
if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"): if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
# Clear the master grads that wouldn't be zeroed by model.zero_grad() # Clear the master grads that wouldn't be zeroed by model.zero_grad()
for param in opt._amp_stash.all_fp32_from_fp16_params: for param in opt._amp_stash.all_fp32_from_fp16_params:
param.grad = None param.grad = None
if hasattr(opt, "most_recent_scale"):
opt.most_recent_scale = 1.0
opt.scale_set_by_backward = False
opt.step = opt_step opt.step = opt_step
opt._amp_stash.already_patched = False opt._amp_stash.already_patched = False
return skip_step return skip_step
......
...@@ -16,7 +16,7 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F ...@@ -16,7 +16,7 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F
master_grad.mul_(scale) master_grad.mul_(scale)
return False return False
def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, check_overflow=False): def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
# Exception handling for 18.04 compatibility # Exception handling for 18.04 compatibility
if check_overflow: if check_overflow:
cpu_sum = float(model_grad.float().sum()) cpu_sum = float(model_grad.float().sum())
...@@ -26,9 +26,8 @@ def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, ch ...@@ -26,9 +26,8 @@ def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, ch
# if master_grad is not model_grad: # copy_ probably internally short-circuits this # if master_grad is not model_grad: # copy_ probably internally short-circuits this
# master_grad.copy_(model_grad) # master_grad.copy_(model_grad)
assert stashed_grad.dtype == master_grad.dtype assert stashed_grad.dtype == master_grad.dtype
converted_model_grad = model_grad.to(master_grad.dtype) converted_model_grad = model_grad.data.to(master_grad.dtype)
stashed_grad.add_(scale, converted_model_grad) master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
master_grad.data = stashed_grad.data
return False return False
class LossScaler(object): class LossScaler(object):
...@@ -125,7 +124,8 @@ class LossScaler(object): ...@@ -125,7 +124,8 @@ class LossScaler(object):
model_grads, model_grads,
stashed_master_grads, stashed_master_grads,
master_grads, master_grads,
scale): a,
b):
for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads): for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
if model is None and stashed is None: if model is None and stashed is None:
continue continue
...@@ -140,7 +140,8 @@ class LossScaler(object): ...@@ -140,7 +140,8 @@ class LossScaler(object):
self._has_overflow = axpby_check_overflow_python(model, self._has_overflow = axpby_check_overflow_python(model,
stashed, stashed,
master, master,
1./scale, a,
b,
self.dynamic) self.dynamic)
if self._has_overflow and self.dynamic: if self._has_overflow and self.dynamic:
break break
...@@ -153,9 +154,9 @@ class LossScaler(object): ...@@ -153,9 +154,9 @@ class LossScaler(object):
if self._has_overflow: if self._has_overflow:
return return
scale = self._loss_scale grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
if scale_override is not None: if scale_override is not None:
scale = scale_override grads_have_scale, stashed_have_scale, out_scale = scale_override
if LossScaler.has_fused_kernel: if LossScaler.has_fused_kernel:
if (not LossScaler.warned_unscaling_non_fp32_grad if (not LossScaler.warned_unscaling_non_fp32_grad
...@@ -169,14 +170,15 @@ class LossScaler(object): ...@@ -169,14 +170,15 @@ class LossScaler(object):
multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda, multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
self._overflow_buf, self._overflow_buf,
[model_grads, stashed_master_grads, master_grads], [model_grads, stashed_master_grads, master_grads],
1./scale, out_scale/grads_have_scale, # 1./scale,
1.0, out_scale/stashed_have_scale, # 1.0,
0) # check only arg 0, aka the incoming model grads, for infs 0) # check only arg 0, aka the incoming model grads, for infs
else: else:
self.unscale_with_stashed_python(model_grads, self.unscale_with_stashed_python(model_grads,
stashed_master_grads, stashed_master_grads,
master_grads, master_grads,
scale) out_scale/grads_have_scale,
out_scale/stashed_have_scale)
# Defer to update_scale # Defer to update_scale
# If the fused kernel is available, we only need one D2H memcopy and sync. # If the fused kernel is available, we only need one D2H memcopy and sync.
......
...@@ -67,7 +67,7 @@ class FusedSGD(Optimizer): ...@@ -67,7 +67,7 @@ class FusedSGD(Optimizer):
super(FusedSGD, self).__init__(params, defaults) super(FusedSGD, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum self.wd_after_momentum = wd_after_momentum
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0 self.most_recent_scale = 1.0
self.scale_set_by_backward = False self.scale_set_by_backward = False
...@@ -138,8 +138,8 @@ class FusedSGD(Optimizer): ...@@ -138,8 +138,8 @@ class FusedSGD(Optimizer):
fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None] fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
if materialize_master_grads: if self.materialize_master_grads:
fp16_params = [p for i, p in enumerate( fp16_model_params = [p for i, p in enumerate(
stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None] stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None] fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None] fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
......
...@@ -154,11 +154,17 @@ void multi_tensor_sgd_cuda( ...@@ -154,11 +154,17 @@ void multi_tensor_sgd_cuda(
auto grad_type = tensor_lists[0][0].scalar_type(); auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type(); auto weight_type = tensor_lists[1][0].scalar_type();
if(num_tensors == 4)
for(int i = 0; i < tensor_lists[3].size(); i++)
AT_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");
// We have 3 possibilities to handle here, in terms of // We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy // grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No // 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No // 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes // 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// It's easier to hardcode these possibilities than to use // It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where // switches etc. to handle the cross-product of cases where
// we don't want the majority of them. // we don't want the majority of them.
...@@ -241,6 +247,26 @@ void multi_tensor_sgd_cuda( ...@@ -241,6 +247,26 @@ void multi_tensor_sgd_cuda(
wd_after_momentum, wd_after_momentum,
scale); scale);
} }
// Case 4. fp32, fp32, fp32, Yes
else if(grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
else else
{ {
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ", AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment