Commit c763f0fe authored by Michael Carilli's avatar Michael Carilli

materialize_master_weights for FusedSGD

parent f3528d99
@@ -130,7 +130,8 @@ def prepare_backward_with_master_weights(self):
     self._amp_lazy_init()

     for i, param in enumerate(stash.all_fp16_params):
-        # Set up to leverage grad copy elision:
+        # Set up to leverage grad copy elision.
+        # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
         param.grad = None

     # for i, param in enumerate(stash.all_fp32_from_fp16_params):
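The new comment above is about a behavioral corner case: because prepare_backward sets `param.grad = None` (so backward can allocate the gradient tensor directly, i.e. grad copy elision), a parameter that receives no gradient in that iteration ends up with `grad is None`, whereas an unpatched optimizer after a conventional `zero_grad()` would see a zero tensor. A minimal standalone sketch of that difference, not apex code:

import torch

# Two parameters; only one participates in the backward pass.
used = torch.nn.Parameter(torch.randn(3))
unused = torch.nn.Parameter(torch.randn(3))

# Simulate a previous iteration followed by a conventional zero_grad():
# both params would hold zero gradient tensors afterwards.
used.grad = torch.zeros(3)
unused.grad = torch.zeros(3)

# The patched prepare_backward instead drops the grads entirely.
used.grad = None
unused.grad = None

(used * 2).sum().backward()

print(used.grad)    # fresh tensor written by backward (no copy into a preexisting buffer)
print(unused.grad)  # None, where an unpatched optimizer would have seen zeros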
@@ -298,31 +299,38 @@ def post_backward_no_master_weights_FusedAdam(self, scaler):
 # FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
 # outside the kernel, so we must accumulate directly into the model grads.
 def prepare_backward_with_master_weights_FusedSGD(self):
-    stash = self._amp_stash
-
-    self._amp_lazy_init()
-
-    for i, param in enumerate(stash.all_fp16_params):
-        stash.all_fp16_grad_stash[i] = param.grad
-        # Set up to leverage grad copy elision:
-        param.grad = None
-
-    for i, param in enumerate(stash.all_fp32_from_fp32_params):
-        stash.all_fp32_from_fp32_grad_stash[i] = param.grad
-        # Set up to leverage grad copy elision:
-        param.grad = None
+    if self.materialize_master_grads:
+        prepare_backward_with_master_weights(self)
+    else:
+        stash = self._amp_stash
+
+        self._amp_lazy_init()
+
+        for i, param in enumerate(stash.all_fp16_params):
+            stash.all_fp16_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None
+
+        for i, param in enumerate(stash.all_fp32_from_fp32_params):
+            stash.all_fp32_from_fp32_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None


 def post_backward_with_master_weights_FusedSGD(self, scaler):
-    stash = self._amp_stash
-
-    self._amp_lazy_init()
-
-    split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
-                   (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
-
-    for params, stashed_grads in split_types:
-        post_backward_models_are_masters(scaler, params, stashed_grads)
+    if self.materialize_master_grads:
+        post_backward_with_master_weights(self, scaler)
+    else:
+        # TODO: handle gradient clipping and removal of any lingering scale here.
+        stash = self._amp_stash
+
+        self._amp_lazy_init()
+
+        split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
+                       (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
+
+        for params, stashed_grads in split_types:
+            post_backward_models_are_masters(scaler, params, stashed_grads)


 def prepare_backward_no_master_weights_FusedSGD(self):
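The non-materializing branches above follow a stash-and-combine pattern: prepare_backward saves any existing model grad and sets `param.grad = None`; after backward, `post_backward_models_are_masters` presumably folds the stashed grads back in and removes the loss scale directly on the model grads. A rough standalone sketch of that cycle for a single parameter, with illustrative values rather than apex internals:

import torch

loss_scale = 512.0
param = torch.nn.Parameter(torch.randn(4))

# Gradient stashed from a previous backward (already in unscaled units).
stashed = torch.ones(4)

# prepare_backward: drop the grad so backward writes a fresh tensor.
param.grad = None

# Backward under amp produces a scaled gradient.
(param.sum() * loss_scale).backward()

# post_backward (conceptually): unscale the fresh grad and add back the stash.
param.grad = param.grad / loss_scale + stashed
print(param.grad)  # tensor([2., 2., 2., 2.])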
@@ -51,7 +51,8 @@ class FusedSGD(Optimizer):
     def __init__(self, params, lr=required, momentum=0, dampening=0,
                  weight_decay=0, nesterov=False,
-                 wd_after_momentum=False):
+                 wd_after_momentum=False,
+                 materialize_master_grads=True):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if momentum < 0.0:
@@ -67,6 +68,8 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum
+        self.materialize_master_grads = materialize_master_grads
+        self.scale = 1.0

         if multi_tensor_applier.available:
             import amp_C
             # Skip buffer
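As a usage sketch (an assumption about how this option is exercised, not something shown in the diff): the flag is passed at construction time, `materialize_master_grads=True` keeps the previous behavior, and `False` opts into the direct model-grad path handled by the amp changes above, using the standard apex O2 flow:

import torch
from apex import amp
from apex.optimizers import FusedSGD

model = torch.nn.Linear(1024, 1024).cuda()

# With materialize_master_grads=False, fp32 master grads are never built outside the kernel;
# the fused kernel consumes the (scaled) fp16 model grads and unscales them via `scale`.
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9,
                     materialize_master_grads=False)

model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(32, 1024, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()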
@@ -130,18 +133,30 @@ class FusedSGD(Optimizer):
             if explicit_master_params:
                 stash = self._amp_stash

-                fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
-                fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
-                fp32_from_fp16_params = [p for i, p in enumerate(
-                    stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
-                fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
-
                 fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                 fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                 fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

-                launch_sets= [[fp16_model_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params],
-                              [fp32_grads, fp32_params, fp32_momentums]]
+                if self.materialize_master_grads:
+                    fp16_model_params = [p for i, p in enumerate(
+                        stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
+                    fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
+                    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
+                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
+
+                    fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
+                                fp32_from_fp16_momentums, fp16_model_params]
+                else:
+                    fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
+                    fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
+                    fp32_from_fp16_params = [p for i, p in enumerate(
+                        stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
+                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
+
+                    fp16_set = [fp16_model_grads, fp32_from_fp16_params,
+                                fp32_from_fp16_momentums, fp16_model_params]
+
+                launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
             else:
                 fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                 fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
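To make the tensor-list layout explicit: each launch set is ordered as [gradients, params to update, momentum buffers], with an optional fourth list of fp16 model params the kernel also writes back (the "fp16, fp32, fp32, Yes" case in the kernel dispatch below). With `materialize_master_grads=False` the gradient list holds the raw fp16 model grads; with `True` it holds the fp32 master grads. A small runnable schematic with dummy tensors, illustrative only:

import torch

# One "launch set": [0] grads, [1] params the kernel updates, [2] momentum buffers,
# [3] optional fp16 model params refreshed by the kernel.
fp16_model_params = [torch.zeros(8, dtype=torch.float16) for _ in range(2)]
fp16_model_grads = [torch.ones(8, dtype=torch.float16) for _ in range(2)]   # scaled fp16 grads (materialize_master_grads=False)
fp32_from_fp16_params = [p.float() for p in fp16_model_params]              # fp32 master weights
fp32_from_fp16_momentums = [torch.zeros(8) for _ in range(2)]               # fp32 momentum buffers

fp16_set = [fp16_model_grads, fp32_from_fp16_params,
            fp32_from_fp16_momentums, fp16_model_params]    # 4 lists -> "fp16, fp32, fp32, Yes"

fp32_grads, fp32_params, fp32_momentums = [torch.ones(8)], [torch.zeros(8)], [torch.zeros(8)]
launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]         # second set is pure fp32
print([len(s) for s in launch_sets])  # [4, 3]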
@@ -168,6 +183,7 @@ class FusedSGD(Optimizer):
                         group['lr'],
                         nesterov,
                         first_run,
-                        self.wd_after_momentum)
+                        self.wd_after_momentum,
+                        self.scale)

         return loss
@@ -16,7 +16,8 @@ void multi_tensor_sgd_cuda(
   float lr,
   bool nesterov,
   bool first_run,
-  bool wd_after_momentum);
+  bool wd_after_momentum,
+  float scale);

 void multi_tensor_axpby_cuda(
   int chunk_size,
@@ -38,7 +38,8 @@ struct SGDFunctor
     float lr,
     bool nesterov,
     bool first_run,
-    bool wd_after_momentum)
+    bool wd_after_momentum,
+    float scale)
   {
     // Early exit if we don't need to do anything
     if (*noop_gmem) return;
@@ -82,7 +83,7 @@ struct SGDFunctor
         int i = i_start + threadIdx.x + ii*blockDim.x;
         if(i < n && i < chunk_size)
         {
-          incoming_grads[ii] = static_cast<float>(grad_in[i]);
+          incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
           incoming_weights[ii] = static_cast<float>(weight_in[i]);
           incoming_moms[ii] = static_cast<float>(mom_in[i]);
         }
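The only functional change in the kernel body is the line above: the gradient is multiplied by `scale` as it is loaded. Presumably amp passes `scale = 1/loss_scale` on this path, so unscaling the fp16 model grads is fused into the optimizer step instead of requiring a separate pass over the gradients; a trivial numeric check of that identity, with made-up values:

# Hypothetical numbers: a "true" gradient produced under a loss scale of 1024.
loss_scale = 1024.0
true_grad = 0.25
scaled_grad = true_grad * loss_scale   # what backward writes into grad_in under amp
kernel_scale = 1.0 / loss_scale        # presumed value of the new `scale` argument

# incoming_grads[ii] = grad_in[i] * scale  ->  recovers the unscaled gradient
assert scaled_grad * kernel_scale == true_grad
print(scaled_grad * kernel_scale)  # 0.25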
@@ -146,7 +147,8 @@ void multi_tensor_sgd_cuda(
   float lr,
   bool nesterov,
   bool first_run,
-  bool wd_after_momentum)
+  bool wd_after_momentum,
+  float scale)
 {
   auto num_tensors = tensor_lists.size();
   auto grad_type = tensor_lists[0][0].scalar_type();
@@ -178,7 +180,8 @@ void multi_tensor_sgd_cuda(
       lr,
       nesterov,
       first_run,
-      wd_after_momentum);
+      wd_after_momentum,
+      scale);
   }
   // Case 2. fp16, fp32, fp32, No
   // else if (grad_type == at::ScalarType::Half &&
@@ -215,7 +218,8 @@ void multi_tensor_sgd_cuda(
       lr,
       nesterov,
       first_run,
-      wd_after_momentum);
+      wd_after_momentum,
+      scale);
   }
   // Case 3. fp16, fp32, fp32, Yes
   else if(grad_type == at::ScalarType::Half &&
@@ -234,7 +238,8 @@ void multi_tensor_sgd_cuda(
       lr,
       nesterov,
       first_run,
-      wd_after_momentum);
+      wd_after_momentum,
+      scale);
   }
   else
   {