Commit 848c777d authored by Michael Carilli

FusedSGD tests passing for all opt_levels

parent c142714b
@@ -91,6 +91,10 @@ def lazy_init_with_master_weights(self):
def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
if scale_override is not None:
grads_have_scale, stashed_have_scale, out_scale = scale_override
# This is a lot of python overhead...
grads_needing_unscale = []
grads_needing_unscale_with_stash = []
@@ -106,20 +110,21 @@ def post_backward_models_are_masters(scaler, params, stashed_grads, scale_overri
else: # param.grad is None and stashed_grad is None
continue
# unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale.
if len(grads_needing_unscale) > 0:
scaler.unscale(
grads_needing_unscale,
grads_needing_unscale,
scaler.loss_scale(),
None, # unused_scale, currently present to avoid API breakage elsewhere
models_are_masters=True,
scale_override=scale_override)
scale_override=grads_have_scale/out_scale)
if len(grads_needing_unscale_with_stash) > 0:
scaler.unscale_with_stashed(
grads_needing_unscale_with_stash,
stashed,
grads_needing_unscale_with_stash,
scale_override=scale_override)
scale_override=(grads_have_scale, stashed_have_scale, out_scale))
# Clear the stash.
for i in range(len(stashed_grads)):
@@ -323,27 +328,25 @@ def post_backward_with_master_weights_FusedSGD(self, scaler):
if self.materialize_master_grads:
post_backward_with_master_weights(self, scaler)
else:
# TODO: handle gradient clipping and removal of any lingering scale here.
stash = self._amp_stash
self._amp_lazy_init()
current_scale = scaler.loss_scale()
out_scale = current_scale
grads_have_scale = scaler.loss_scale()
stashed_have_scale = self.most_recent_scale
out_scale = grads_have_scale
if self.scale_set_by_backward:
out_scale = min(current_scale, self.most_recent_scale)
scale_adjustment = out_scale/current_scale
out_scale = min(grads_have_scale, self.most_recent_scale)
split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
(stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
# Grads created by this backward pass have been scaled by current_scale.
# unscale() implements grads*1/scale, so "scale" should be current_scale/out_scale
# unscale_with_stashed() implements grads*1/scale + stashed_grads*1.
# stashed_grads are scaled by self.most_recent_scale.
for params, stashed_grads in split_types:
post_backward_models_are_masters(scaler, params, stashed_grads)
post_backward_models_are_masters(scaler, params, stashed_grads,
(grads_have_scale, stashed_have_scale, out_scale))
self.most_recent_scale = out_scale
self.scale_set_by_backward = True
......
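For reference, a minimal standalone sketch of the scale bookkeeping the hunks above introduce for the materialize_master_grads=False FusedSGD path. The free-function form and names are illustrative only; the real logic lives in post_backward_with_master_weights_FusedSGD and is handed to post_backward_models_are_masters as the scale_override triple.

def choose_out_scale(current_loss_scale, most_recent_scale, scale_set_by_backward):
    # Mirrors the hunk above: grads produced by this backward carry the current loss
    # scale, stashed grads carry the scale of the backward that produced them, and
    # the merged grads are left carrying the smaller of the two.
    grads_have_scale = current_loss_scale
    stashed_have_scale = most_recent_scale
    out_scale = grads_have_scale
    if scale_set_by_backward:
        out_scale = min(grads_have_scale, most_recent_scale)
    # This is the triple passed to post_backward_models_are_masters as scale_override.
    return grads_have_scale, stashed_have_scale, out_scale

# A previous backward left stashed grads scaled by 16.0; the current loss scale is 4.0.
assert choose_out_scale(4.0, 16.0, True) == (4.0, 16.0, 4.0)
# Nothing stashed yet (e.g. right after a skipped step reset most_recent_scale to 1.0).
assert choose_out_scale(4.0, 1.0, False) == (4.0, 1.0, 4.0)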
@@ -132,10 +132,15 @@ def scale_loss(loss,
maybe_print(("Gradient overflow. Skipping step, loss scaler " +
"{} reducing loss scale to {}").format(loss_id,
loss_scaler.loss_scale()))
# TODO: I don't like the special casing for different optimizer implementations.
# Maybe skip should delegate to a method owned by the optimizers themselves.
if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
# Clear the master grads that wouldn't be zeroed by model.zero_grad()
for param in opt._amp_stash.all_fp32_from_fp16_params:
param.grad = None
if hasattr(opt, "most_recent_scale"):
opt.most_recent_scale = 1.0
opt.scale_set_by_backward = False
opt.step = opt_step
opt._amp_stash.already_patched = False
return skip_step
......
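A hedged sketch of the skipped-step reset the hunk above performs inside scale_loss's skip_step closure, written here as a free function over a hypothetical opt for readability; it assumes only the attributes visible in the hunk.

def reset_fused_sgd_after_skipped_step(opt):
    # Clear master grads that model.zero_grad() would not touch (only present when
    # the optimizer keeps fp32-from-fp16 master params).
    stash = getattr(opt, "_amp_stash", None)
    if stash is not None and hasattr(stash, "all_fp32_from_fp16_params"):
        for param in stash.all_fp32_from_fp16_params:
            param.grad = None
    # FusedSGD tracks the scale its stashed grads carry; after a skipped step there
    # is nothing stashed, so the bookkeeping returns to a clean scale of 1.0.
    if hasattr(opt, "most_recent_scale"):
        opt.most_recent_scale = 1.0
        opt.scale_set_by_backward = False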
@@ -16,7 +16,7 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F
master_grad.mul_(scale)
return False
def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, check_overflow=False):
def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
# Exception handling for 18.04 compatibility
if check_overflow:
cpu_sum = float(model_grad.float().sum())
@@ -26,9 +26,8 @@ def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, ch
# if master_grad is not model_grad: # copy_ probably internally short-circuits this
# master_grad.copy_(model_grad)
assert stashed_grad.dtype == master_grad.dtype
converted_model_grad = model_grad.to(master_grad.dtype)
stashed_grad.add_(scale, converted_model_grad)
master_grad.data = stashed_grad.data
converted_model_grad = model_grad.data.to(master_grad.dtype)
master_grad.data = a*converted_model_grad.data + b*stashed_grad.data
return False
class LossScaler(object):
@@ -125,7 +124,8 @@ class LossScaler(object):
model_grads,
stashed_master_grads,
master_grads,
scale):
a,
b):
for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
if model is None and stashed is None:
continue
@@ -140,7 +140,8 @@ class LossScaler(object):
self._has_overflow = axpby_check_overflow_python(model,
stashed,
master,
1./scale,
a,
b,
self.dynamic)
if self._has_overflow and self.dynamic:
break
@@ -153,9 +154,9 @@ class LossScaler(object):
if self._has_overflow:
return
scale = self._loss_scale
grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
if scale_override is not None:
scale = scale_override
grads_have_scale, stashed_have_scale, out_scale = scale_override
if LossScaler.has_fused_kernel:
if (not LossScaler.warned_unscaling_non_fp32_grad
@@ -169,14 +170,15 @@ class LossScaler(object):
multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
self._overflow_buf,
[model_grads, stashed_master_grads, master_grads],
1./scale,
1.0,
out_scale/grads_have_scale, # 1./scale,
out_scale/stashed_have_scale, # 1.0,
0) # check only arg 0, aka the incoming model grads, for infs
else:
self.unscale_with_stashed_python(model_grads,
stashed_master_grads,
master_grads,
scale)
out_scale/grads_have_scale,
out_scale/stashed_have_scale)
# Defer to update_scale
# If the fused kernel is available, we only need one D2H memcopy and sync.
......
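A small self-contained example (CPU tensors, Python fallback path, no fused kernel) of the a/b convention unscale_with_stashed now uses: a = out_scale/grads_have_scale multiplies the incoming model grads and b = out_scale/stashed_have_scale multiplies the stashed grads, matching axpby_check_overflow_python and the multi_tensor_axpby_cuda call above.

import torch

grads_have_scale, stashed_have_scale = 4.0, 16.0
out_scale = min(grads_have_scale, stashed_have_scale)
a = out_scale / grads_have_scale    # 1.0, applied to the incoming model grad
b = out_scale / stashed_have_scale  # 0.25, applied to the stashed grad

model_grad   = torch.full((3,), 0.5) * grads_have_scale     # true grad 0.5, scaled by 4.0
stashed_grad = torch.full((3,), 0.25) * stashed_have_scale  # true grad 0.25, scaled by 16.0
master_grad  = torch.empty(3)

# Same arithmetic as axpby_check_overflow_python above, minus the overflow check.
master_grad.data = a * model_grad.to(master_grad.dtype) + b * stashed_grad

# Both contributions now carry out_scale (4.0 here), ready for a later unscale by 1/out_scale.
assert torch.allclose(master_grad, out_scale * torch.full((3,), 0.75))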
@@ -67,7 +67,7 @@ class FusedSGD(Optimizer):
super(FusedSGD, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
@@ -138,8 +138,8 @@ class FusedSGD(Optimizer):
fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
if materialize_master_grads:
fp16_params = [p for i, p in enumerate(
if self.materialize_master_grads:
fp16_model_params = [p for i, p in enumerate(
stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
......
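The FusedSGD hunk above filters the fp16 model params down to those whose fp32 master copy actually received a gradient, so every list handed to the fused kernel stays element-wise aligned. A standalone illustration of that pairing, using small hypothetical CPU parameters:

import torch

# Two aligned groups: fp16 model params and their fp32 master copies.
fp16_group = [torch.nn.Parameter(torch.zeros(2, dtype=torch.float16)) for _ in range(3)]
fp32_group = [torch.nn.Parameter(p.detach().float()) for p in fp16_group]

# Suppose only the first and last masters received grads this iteration.
fp32_group[0].grad = torch.ones(2)
fp32_group[2].grad = torch.ones(2)

# Same filtering pattern as the hunk: select fp16 params by whether the corresponding
# fp32 master has a grad, so all three lists stay element-wise aligned.
fp16_model_params = [p for i, p in enumerate(fp16_group) if fp32_group[i].grad is not None]
fp32_from_fp16_grads = [p.grad for p in fp32_group if p.grad is not None]
fp32_from_fp16_params = [p for p in fp32_group if p.grad is not None]

assert len(fp16_model_params) == len(fp32_from_fp16_grads) == len(fp32_from_fp16_params) == 2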
@@ -154,11 +154,17 @@ void multi_tensor_sgd_cuda(
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if(num_tensors == 4)
for(int i = 0; i < tensor_lists[3].size(); i++)
AT_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");
// We have 4 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
@@ -241,6 +247,26 @@ void multi_tensor_sgd_cuda(
wd_after_momentum,
scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if(grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
else
{
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
......
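The kernel now dispatches over four (grad_type, weight_type, num_tensors) combinations instead of three. A pure-Python restatement of that dispatch table, with torch dtypes standing in for at::ScalarType; the num_tensors guards on the first two cases are inferred from the surrounding code rather than shown in the hunk.

import torch

def sgd_dispatch_case(grad_dtype, weight_dtype, num_tensor_lists):
    # Case numbering follows the comment block in multi_tensor_sgd_cuda.
    if grad_dtype == torch.float16 and weight_dtype == torch.float16 and num_tensor_lists == 3:
        return 1  # fp16 grads, fp16 params, fp16 momentum, no extra fp16 copy
    if grad_dtype == torch.float32 and weight_dtype == torch.float32 and num_tensor_lists == 3:
        return 2  # all fp32, no extra fp16 copy
    if grad_dtype == torch.float16 and weight_dtype == torch.float32 and num_tensor_lists == 4:
        return 3  # fp16 grads, fp32 master params/momentum, fp16 copy written out
    if grad_dtype == torch.float32 and weight_dtype == torch.float32 and num_tensor_lists == 4:
        return 4  # new case: fp32 master grads/params/momentum plus fp16 copy
    raise ValueError("unsupported combination of gradient & weight types")

# The materialize_master_grads=True path added by this commit hits case 4.
assert sgd_dispatch_case(torch.float32, torch.float32, 4) == 4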
@@ -77,191 +77,50 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
final_params = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (True, False):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
else:
iters = 2
model0 = MyModel(1)
model1 = MyModel(2)
models = [model0, model1]
optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.125)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1], optimizer = amp.initialize(
[model0, model1],
optimizer,
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if inject_inf_loc == "fp32":
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if inject_inf_loc == "fp32":
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model1.weight1.grad[0] = float('inf')
if i != inject_inf:
if opt_level == "O2":
master_params = list(model0.parameters()) + list(model1.parameters())
else:
master_params = amp.master_params(optimizer)
for param, reference_grad in zip(master_params, reference_grads[unskipped]):
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
unskipped += 1
optimizer.step()
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
for model, master, reference in zip(
model_params,
amp.master_params(optimizer),
final_params):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_3models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125)
reference_grads = []
for i in range(2):
optimizer.zero_grad()
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
loss0.backward()
loss1.backward()
reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()] +
[param.grad.data.clone() for param in model2.parameters()])
optimizer.step()
final_params = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (True, False):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 2)
elif which_backward == 1:
which_models = (1, 2)
else:
iters = 2
which_models = (None,)
if inject_inf >= 0:
iters = 3
else:
iters = 2
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
models = [model0, model1]
optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125)
{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.125,
materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], optimizer = amp.initialize(
[model0, model1, model2],
[model0, model1], optimizer = amp.initialize(
[model0, model1],
optimizer,
opt_level=opt_level,
verbosity=0,
@@ -285,53 +144,36 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
else:
optimizer.zero_grad()
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers))
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
loss0 = model0(self.x)
loss1 = model1(self.x)
with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError(which_model + " invalid for loss 0")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 1:
inj_model = model1
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError(which_model + " invalid for loss 1 ")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
model1.weight1.grad[0] = float('inf')
if i != inject_inf:
if opt_level == "O2":
master_params = list(model0.parameters()) + list(model1.parameters()) + \
list(model2.parameters())
else:
master_params = amp.master_params(optimizer)
master_params = amp.master_params(optimizer)
for param, reference_grad in zip(master_params, reference_grads[unskipped]):
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),
"opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
unskipped += 1
optimizer.step()
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
for model, master, reference in zip(
model_params,
amp.master_params(optimizer),
@@ -342,6 +184,167 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_3models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125)
reference_grads = []
for i in range(2):
optimizer.zero_grad()
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
loss0.backward()
loss1.backward()
reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()] +
[param.grad.data.clone() for param in model2.parameters()])
optimizer.step()
final_params = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 2)
elif which_backward == 1:
which_models = (1, 2)
else:
iters = 2
which_models = (None,)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125,
materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], optimizer = amp.initialize(
[model0, model1, model2],
optimizer,
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer.zero_grad()
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError("{} invalid for loss 0".format(which_model))
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 1:
inj_model = model1
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError("{} invalid for loss 1".format(which_model))
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = amp.master_params(optimizer)
for param, reference_grad in zip(master_params, reference_grads[unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),
"opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers))
unskipped += 1
optimizer.step()
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
for model, master, reference in zip(
model_params,
amp.master_params(optimizer),
final_params):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_2models2losses2optimizers(self):
model0 = MyModel(1)
@@ -422,119 +425,120 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (True, False):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
else:
iters = 2
if inject_inf >= 0:
iters = 3
else:
iters = 2
model0 = MyModel(1)
model1 = MyModel(2)
models = [model0, model1]
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125, materialize_master_grads=materialize_master_grads)
optimizer1 = FusedSGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25, materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1], [optimizer0, optimizer1] = amp.initialize(
[model0, model1],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
model0 = MyModel(1)
model1 = MyModel(2)
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
models = [model0, model1]
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
optimizer1 = FusedSGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25)
loss0 = model0(self.x)
loss1 = model1(self.x)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1], [optimizer0, optimizer1] = amp.initialize(
[model0, model1],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if inject_inf_loc == "fp32":
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if inject_inf_loc == "fp32":
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model1.weight1.grad[0] = float('inf')
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if inject_inf_loc == "fp32":
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if inject_inf_loc == "fp32":
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model1.weight1.grad[0] = float('inf')
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
if i != inject_inf:
if opt_level == "O2":
master_params = list(model0.parameters()) + list(model1.parameters())
else:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]):
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
unskipped += 1
optimizer0.step()
optimizer1.step()
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward)]):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
if i != inject_inf:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
unskipped += 1
optimizer0.step()
optimizer1.step()
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward)]):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_3models2losses2optimizers(self):
@@ -647,145 +651,149 @@ class TestMultipleModelsOptimizersLosses(unittest.TestCase):
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (True, False):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 1)
elif which_backward == 1:
which_models = (2, 1)
else:
iters = 2
which_models = (None,)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5)
optimizer1 = FusedSGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], [optimizer0, optimizer1] = amp.initialize(
[model0, model1, model2],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError(which_model + " invalid for loss 0")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 2:
inj_model = model2
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError(which_model + " invalid for loss 1 ")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
if i != inject_inf:
if opt_level == "O2":
master_params = list(model0.parameters()) + \
list(model1.parameters()) + \
list(model2.parameters())
else:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf,
which_backward, which_model)][unskipped]):
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
unskipped += 1
optimizer0.step()
optimizer1.step()
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model))
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward, which_model)]):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if opt_level == "O1":
_amp_state.handle._deactivate()
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 1)
elif which_backward == 1:
which_models = (2, 1)
else:
iters = 2
which_models = (None,)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5, materialize_master_grads=materialize_master_grads)
optimizer1 = FusedSGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25, materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], [optimizer0, optimizer1] = amp.initialize(
[model0, model1, model2],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError("{} invalid for loss 0".format(which_model))
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 2:
inj_model = model2
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError("{} invalid for loss 1".format(which_model))
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
if i != inject_inf:
if opt_level == "O2" and not materialize_master_grads:
master_params = list(model0.parameters()) + \
list(model1.parameters()) + \
list(model2.parameters())
else:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf,
which_backward, which_model)][unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
unskipped += 1
optimizer0.step()
optimizer1.step()
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model))
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward, which_model)]):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
if __name__ == '__main__':
unittest.main()
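For orientation, the pattern these tests exercise, reduced to a minimal training-loop sketch. The model, data, and hyperparameters here are placeholders; running it requires a CUDA build of apex with the amp_C extension.

import torch
from apex import amp
from apex.optimizers import FusedSGD

model = torch.nn.Linear(16, 4).cuda()
optimizer = FusedSGD(model.parameters(), lr=0.25, momentum=0.125,
                     materialize_master_grads=False)  # both settings are exercised above

model, optimizer = amp.initialize(model, optimizer, opt_level="O2", verbosity=0)

x = torch.randn(8, 16).cuda()
for _ in range(2):
    optimizer.zero_grad()
    loss = model(x).sum()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()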