Commit c763f0fe authored by Michael Carilli

materialize_master_weights for FusedSGD

parent f3528d99
@@ -130,7 +130,8 @@ def prepare_backward_with_master_weights(self):
     self._amp_lazy_init()

     for i, param in enumerate(stash.all_fp16_params):
-        # Set up to leverage grad copy elision:
+        # Set up to leverage grad copy elision.
+        # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
         param.grad = None

     # for i, param in enumerate(stash.all_fp32_from_fp16_params):
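Note: the added comment flags a behavioral difference. With grad copy elision, .grad is set to None before backward; if the parameter then goes unused, it keeps grad = None, whereas an optimizer whose zero_grad() zeroes gradients in place would leave a zero tensor behind. A standalone toy illustration of that difference (plain PyTorch, not apex internals):

import torch

# Toy illustration of the caveat in the new comment (not apex code).
used = torch.randn(3, requires_grad=True)
unused = torch.randn(3, requires_grad=True)
(used.sum() + unused.sum()).backward()     # both params now hold a .grad tensor

# In-place zero_grad() would keep (zeroed) grad tensors around; the patched hook
# instead drops the grads entirely before the next backward.
used.grad = None
unused.grad = None

(used * 2.0).sum().backward()              # `unused` takes no part in this graph
print(used.grad)    # fresh tensor: the copy/accumulate into an old buffer is elided
print(unused.grad)  # None, whereas in-place zero_grad() would have left zeros here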
@@ -298,31 +299,38 @@ def post_backward_no_master_weights_FusedAdam(self, scaler):

 # FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
 # outside the kernel, so we must accumulate directly into the model grads.
 def prepare_backward_with_master_weights_FusedSGD(self):
-    stash = self._amp_stash
+    if self.materialize_master_grads:
+        prepare_backward_with_master_weights(self)
+    else:
+        stash = self._amp_stash

-    self._amp_lazy_init()
+        self._amp_lazy_init()

-    for i, param in enumerate(stash.all_fp16_params):
-        stash.all_fp16_grad_stash[i] = param.grad
-        # Set up to leverage grad copy elision:
-        param.grad = None
+        for i, param in enumerate(stash.all_fp16_params):
+            stash.all_fp16_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None

-    for i, param in enumerate(stash.all_fp32_from_fp32_params):
-        stash.all_fp32_from_fp32_grad_stash[i] = param.grad
-        # Set up to leverage grad copy elision:
-        param.grad = None
+        for i, param in enumerate(stash.all_fp32_from_fp32_params):
+            stash.all_fp32_from_fp32_grad_stash[i] = param.grad
+            # Set up to leverage grad copy elision:
+            param.grad = None

 def post_backward_with_master_weights_FusedSGD(self, scaler):
-    stash = self._amp_stash
+    if self.materialize_master_grads:
+        post_backward_with_master_weights(self, scaler)
+    else:
+        # TODO: handle gradient clipping and removal of any lingering scale here.
+        stash = self._amp_stash

-    self._amp_lazy_init()
+        self._amp_lazy_init()

-    split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
-                   (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
+        split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
+                       (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))

-    for params, stashed_grads in split_types:
-        post_backward_models_are_masters(scaler, params, stashed_grads)
+        for params, stashed_grads in split_types:
+            post_backward_models_are_masters(scaler, params, stashed_grads)

 def prepare_backward_no_master_weights_FusedSGD(self):
...
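Note: with the default materialize_master_grads=True, these hooks simply defer to the generic prepare_backward_with_master_weights / post_backward_with_master_weights path; only with materialize_master_grads=False do gradients stay on the fp16 model params, stashed and restored around backward, leaving un-scaling to the fused kernel. A conceptual sketch of the two paths (hypothetical helper, not apex internals; inv_scale stands in for whatever value un-scales the gradients):

import torch

# Conceptual sketch of the two gradient-handling paths selected above (hypothetical helper, not apex code).
def sync_master_grads(fp16_params, fp32_master_params, materialize_master_grads, inv_scale):
    if materialize_master_grads:
        # Default path: materialize fp32 master grads from the fp16 model grads, un-scaling on the way.
        for p16, p32 in zip(fp16_params, fp32_master_params):
            if p16.grad is not None:
                p32.grad = p16.grad.float() * inv_scale
    else:
        # New path: gradients stay on the fp16 model params; FusedSGD's kernel reads them
        # directly and applies the `scale` argument added in this commit.
        pass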
@@ -51,7 +51,8 @@ class FusedSGD(Optimizer):
     def __init__(self, params, lr=required, momentum=0, dampening=0,
                  weight_decay=0, nesterov=False,
-                 wd_after_momentum=False):
+                 wd_after_momentum=False,
+                 materialize_master_grads=True):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if momentum < 0.0:
@@ -67,6 +68,8 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum

+        self.scale = 1.0
+
         if multi_tensor_applier.available:
             import amp_C
             # Skip buffer
@@ -130,18 +133,30 @@ class FusedSGD(Optimizer):
             if explicit_master_params:
                 stash = self._amp_stash

-                fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
-                fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
-                fp32_from_fp16_params = [p for i, p in enumerate(
-                    stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
-                fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
-
                 fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                 fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                 fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)

-                launch_sets= [[fp16_model_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params],
-                              [fp32_grads, fp32_params, fp32_momentums]]
+                if materialize_master_grads:
+                    fp16_model_params = [p for i, p in enumerate(
+                        stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
+                    fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
+                    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
+                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
+
+                    fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
+                                fp32_from_fp16_momentums, fp16_model_params]
+                else:
+                    fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
+                    fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
+                    fp32_from_fp16_params = [p for i, p in enumerate(
+                        stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
+                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
+
+                    fp16_set = [fp16_model_grads, fp32_from_fp16_params,
+                                fp32_from_fp16_momentums, fp16_model_params]
+
+                launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
             else:
                 fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                 fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
@@ -168,6 +183,7 @@ class FusedSGD(Optimizer):
                     group['lr'],
                     nesterov,
                     first_run,
-                    self.wd_after_momentum)
+                    self.wd_after_momentum,
+                    self.scale)

         return loss
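Note: the Python-side change adds a materialize_master_grads keyword (default True, which preserves the old behavior) and a self.scale attribute initialized to 1.0 and forwarded to the fused kernel, which multiplies each incoming gradient by it. What updates self.scale during mixed-precision training is not shown in this diff. A hedged construction sketch (assumes a CUDA build of apex is installed; the commented scale assignment is illustrative only, not something this commit wires up):

import torch
from apex.optimizers import FusedSGD

model = torch.nn.Linear(1024, 1024).cuda()
opt = FusedSGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4,
               wd_after_momentum=False,
               materialize_master_grads=True)   # new keyword from this commit; True keeps prior behavior

# With materialize_master_grads=False the fp16 model grads are consumed directly by the
# fused kernel, which multiplies them by opt.scale; opt.scale stays 1.0 unless something
# (e.g. the amp backward hooks) assigns it.
# opt.scale = 1.0 / loss_scale   # illustrative assumption, not shown in this commit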
@@ -16,7 +16,8 @@ void multi_tensor_sgd_cuda(
   float lr,
   bool nesterov,
   bool first_run,
-  bool wd_after_momentum);
+  bool wd_after_momentum,
+  float scale);

 void multi_tensor_axpby_cuda(
   int chunk_size,
...
@@ -38,7 +38,8 @@ struct SGDFunctor
     float lr,
     bool nesterov,
     bool first_run,
-    bool wd_after_momentum)
+    bool wd_after_momentum,
+    float scale)
   {
     // Early exit if we don't need to do anything
     if (*noop_gmem) return;
@@ -82,7 +83,7 @@ struct SGDFunctor
         int i = i_start + threadIdx.x + ii*blockDim.x;
         if(i < n && i < chunk_size)
         {
-          incoming_grads[ii] = static_cast<float>(grad_in[i]);
+          incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
           incoming_weights[ii] = static_cast<float>(weight_in[i]);
           incoming_moms[ii] = static_cast<float>(mom_in[i]);
         }
@@ -146,7 +147,8 @@ void multi_tensor_sgd_cuda(
   float lr,
   bool nesterov,
   bool first_run,
-  bool wd_after_momentum)
+  bool wd_after_momentum,
+  float scale)
 {
   auto num_tensors = tensor_lists.size();
   auto grad_type = tensor_lists[0][0].scalar_type();
@@ -178,7 +180,8 @@ void multi_tensor_sgd_cuda(
       lr,
       nesterov,
       first_run,
-      wd_after_momentum);
+      wd_after_momentum,
+      scale);
   }
   // Case 2. fp16, fp32, fp32, No
   // else if (grad_type == at::ScalarType::Half &&
@@ -215,7 +218,8 @@ void multi_tensor_sgd_cuda(
       lr,
       nesterov,
       first_run,
-      wd_after_momentum);
+      wd_after_momentum,
+      scale);
   }
   // Case 3. fp16, fp32, fp32, Yes
   else if(grad_type == at::ScalarType::Half &&
@@ -234,7 +238,8 @@ void multi_tensor_sgd_cuda(
       lr,
       nesterov,
       first_run,
-      wd_after_momentum);
+      wd_after_momentum,
+      scale);
   }
   else
   {
...
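Note: on the CUDA side the new float scale argument is threaded through multi_tensor_sgd_cuda into SGDFunctor and applied once, as each gradient element is loaded (incoming_grads[ii] = grad_in[i] * scale), so the rest of the momentum/weight-decay math is unchanged. A simplified Python reference of that per-tensor update, under the assumptions of fp32 tensors, dampening handled as in torch.optim.SGD, and weight decay applied before momentum; the real functor also handles nesterov, first_run, and the mixed fp16/fp32 cases:

import torch

# Simplified Python reference of the fused update with the new `scale` argument (illustration only).
def sgd_step_reference(weight, grad, momentum_buf, *, weight_decay, momentum, dampening, lr, scale):
    g = grad * scale                               # matches: incoming_grads[ii] = grad_in[i] * scale
    g = g + weight_decay * weight                  # wd_after_momentum=False path
    momentum_buf.mul_(momentum).add_(g, alpha=1.0 - dampening)
    weight.add_(momentum_buf, alpha=-lr)           # in-place parameter update
    return weight, momentum_buf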