Commit 75139ca3 authored by Michael Carilli

let's see

parent e0f2ffa5
@@ -11,6 +11,20 @@ class AmpOptimizerState(object):
pass
def _master_params_to_model_params(self):
stash = self._amp_stash
if multi_tensor_applier.available:
if len(stash.all_fp16_params) > 0:
multi_tensor_applier(
stash.multi_tensor_scale,
stash.dummy_overflow_buf,
[stash.all_fp32_from_fp16_params, stash.all_fp16_params],
1.0)
else:
for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
master_params_to_model_params(fp16_group, fp32_from_fp16_group)
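For orientation, a minimal pure-Python sketch of what the multi_tensor_applier fast path above achieves (illustrative only; the fused path launches one kernel over all tensors instead of looping, and the helper name here is hypothetical, not apex API):

def copy_masters_to_model_params(all_fp32_from_fp16_params, all_fp16_params):
    # Same effect as multi_tensor_scale with a scale of 1.0: cast/copy each
    # fp32 master weight into its corresponding fp16 model weight.
    for master, model in zip(all_fp32_from_fp16_params, all_fp16_params):
        model.data.copy_(master.data)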
def lazy_init_with_master_weights(self):
stash = self._amp_stash
stash.fp16_groups = []
@@ -277,6 +291,8 @@ def post_backward_no_master_weights_FusedAdam(self, scaler):
# implementations until I have them both working.
#####################################################################################
# FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
# outside the kernel, so we must accumulate directly into the model grads.
def prepare_backward_with_master_weights_FusedSGD(self):
stash = self._amp_stash
@@ -284,60 +300,33 @@ def prepare_backward_with_master_weights_FusedSGD(self):
self._lazy_init_maybe_master_weights()
stash.lazy_init_called = True
for i, param in enumerate(stash.all_fp16_params):
stash.all_fp16_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
for i, param in enumerate(stash.all_fp32_from_fp32_params):
stash.all_fp32_from_fp32_grad_stash[i] = param.grad
# Set up to leverage grad copy elision:
param.grad = None
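In isolation, the stash-and-clear step above (the "grad copy elision" setup) amounts to the following sketch; the function name is illustrative, not part of apex:

def stash_and_clear_grads(params, grad_stash):
    # Remember whatever gradient each param currently holds, then clear it so
    # the upcoming backward pass allocates a fresh .grad instead of launching
    # an extra accumulation kernel into the old one.
    for i, p in enumerate(params):
        grad_stash[i] = p.grad
        p.grad = None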
def post_backward_with_master_weights_FusedSGD(self, scaler):
stash = self._amp_stash
-stash.scale = scaler.loss_scale()
-stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
-stash.output_params = [[param for param in group] for group in stash.fp16_groups]
-norm_groups = []
-skip = False
-for grad_group in stash.grads:
-norm = multi_tensor_applier(
-stash.multi_tensor_l2norm,
-stash.dummy_overflow_buf,
-[grad_group])
-# Still syncing here for now.
-norm = float(norm)
-norm_groups.append(norm)
-if norm == float('inf') or norm == -float('inf') or norm != norm:
-skip = True
-if skip:
-scaler._overflow_buf.fill_(1.)
-scaler._has_overflow = True
split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
(stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
-stash.grad_norms = norm_groups
for params, stashed_grads in split_types:
post_backward_models_are_masters(scaler, params, stashed_grads)
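Conceptually, post_backward_models_are_masters has to reconcile the freshly produced (loss-scaled) grads with the stashed (already unscaled) ones for each bucket. A rough sketch of that bookkeeping, assuming a plain per-tensor loop and ignoring overflow checking (the real helper may use fused multi-tensor unscaling under the hood):

def combine_and_unscale(params, stashed_grads, loss_scale):
    for i, p in enumerate(params):
        new_grad, old_grad = p.grad, stashed_grads[i]
        if new_grad is not None:
            new_grad.div_(loss_scale)      # grads from backward carry the loss scale
            if old_grad is not None:
                new_grad.add_(old_grad)    # stashed grads were already unscaled
        else:
            p.grad = old_grad              # nothing new produced; restore the stash
        stashed_grads[i] = None            # clear the stash for the next iteration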
def prepare_backward_no_master_weights_FusedSGD(self):
stash = self._amp_stash
if not stash.lazy_init_called:
self._lazy_init_maybe_master_weights()
stash.lazy_init_called = True
prepare_backward_no_master_weights(self)
def post_backward_no_master_weights_FusedSGD(self, scaler):
stash = self._amp_stash
-stash.scale = scaler.loss_scale()
-stash.grads = None
-stash.output_params = None
-stash.grad_norms = None
-def _master_params_to_model_params(self):
-stash = self._amp_stash
-if multi_tensor_applier.available:
-if len(stash.all_fp16_params) > 0:
-multi_tensor_applier(
-stash.multi_tensor_scale,
-stash.dummy_overflow_buf,
-[stash.all_fp32_from_fp16_params, stash.all_fp16_params],
-1.0)
-else:
-for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
-master_params_to_model_params(fp16_group, fp32_from_fp16_group)
post_backward_no_master_weights(self, scaler)
def _process_optimizer(optimizer, properties):
@@ -80,6 +80,22 @@ class FusedSGD(Optimizer):
for group in self.param_groups:
group.setdefault('nesterov', False)
def get_momentums(self, params):
momentums = []
for p in params:
param_state = self.state[p]
# torch.optim.SGD initializes momentum in the main loop; we have
# to do it here and track whether or not we've done so, so that
# momentum application can be skipped in the main kernel.
if 'momentum_buffer' not in param_state:
first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
momentums.append(buf)
else:
first_run = False
momentums.append(param_state['momentum_buffer'])
return momentums, first_run
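The first_run flag exists so the kernel can skip reading a momentum buffer that was only just zero-initialized. A non-fused sketch of the same logic for a single parameter (plain momentum SGD, dampening and Nesterov omitted; names are illustrative):

import torch

def momentum_sgd_step(p, lr, momentum, first_run, state):
    if 'momentum_buffer' not in state:
        state['momentum_buffer'] = torch.zeros_like(p.data)
    buf = state['momentum_buffer']
    if first_run:
        buf.copy_(p.grad)                 # no history yet: buffer is just the grad
    else:
        buf.mul_(momentum).add_(p.grad)   # usual momentum accumulation
    p.data.add_(buf, alpha=-lr)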
def step(self, closure=None):
"""Performs a single optimization step.
@@ -91,73 +107,59 @@ class FusedSGD(Optimizer):
if closure is not None:
loss = closure()
-for group in self.param_groups:
explicit_master_params = (hasattr(self, "_amp_stash") and
hasattr(self._amp_stash, "fp32_from_fp16_groups"))
for gid, group in enumerate(self.param_groups):
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
-params = [p for p in group['params'] if p is not None]
-grads = [p.grad for p in params]
-momentums = []
-for p in params:
-param_state = self.state[p]
-# torch.optim.SGD initializes momentum in the main loop, we have
-# to do it here, and track whether or not we've done so, so that
-# momentum application can be skipped in the main kernel.
-if 'momentum_buffer' not in param_state:
-first_run = True
-buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
-momentums.append(buf)
-else:
-first_run = False
-momentums.append(param_state['momentum_buffer'])
-# We have all parameters now, split them into appropriate groups for
-# parallel execution, following the 4 possible combos that the underlying
-# kernels support:
-# grad_type, param_type, momentum_type, requires_fp16_copy
# For each group, there are 3 possible combinations we need to consider:
# grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
# 1. fp16, fp16, fp16, No
-# 2. fp16, fp32, fp32, No
# 2. fp32, fp32, fp32, No
# 3. fp16, fp32, fp32, Yes
-# 4. fp32, fp32, fp32, No
-# As in the kernel, easier to hardcode these options
-# Store only indices into the weight / grad / momentum lists
-# { gradient-type : { param-type : List } | List }
-param_sets = { 'fp16' : { 'fp16' : [], 'fp32' : [] }, 'fp32' : [] }
-for i, (g, p) in enumerate(zip(grads, params)):
-if g.dtype == torch.float16:
-# fp16 grads, fp16 params
-if p.dtype == torch.float16:
-param_sets['fp16']['fp16'].append(i)
-# fp16 grads, fp32 params
-elif p.dtype == torch.float32:
-param_sets['fp16']['fp32'].append(i)
-else:
-raise RuntimeError('fp16 gradients need either fp16 or fp32 weights')
-# fp32 grads, fp32 params
-elif g.dtype == torch.float32:
-param_sets['fp32'].append(i)
-else:
-raise RuntimeError('gradients must either be fp16 or fp32')
-def launch_sgd_set(param_set):
-local_params, local_grads, local_momentums = [], [], []
-if len(param_set) == 0:
-return
-# launch update using multi tensor applier
-# modifies weight and momentum values inplace.
first_runs = [True, True]
# I think a bit of code divergence in exchange for naming clarity is worthwhile
if explicit_master_params:
stash = self._amp_stash
fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for i, p in enumerate(
stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
launch_sets= [[fp16_model_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params],
[fp32_grads, fp32_params, fp32_momentums]]
else:
fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
[fp32_grads, fp32_params, fp32_momentums]]
for launch_set, first_run in zip(launch_sets, first_runs):
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
-# Note: Need to do this as list comprehensions otherwise
-# things don't seem to update properly.
-[[grads[i] for i in param_set],
-[params[i] for i in param_set],
-[momentums[i] for i in param_set]],
launch_set,
weight_decay,
momentum,
dampening,
@@ -166,9 +168,4 @@ class FusedSGD(Optimizer):
first_run,
self.wd_after_momentum)
-# Explicitly go over the cases
-launch_sgd_set(param_sets['fp16']['fp16'])
-launch_sgd_set(param_sets['fp16']['fp32'])
-launch_sgd_set(param_sets['fp32'])
return loss
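End to end, a minimal usage sketch (hedged: assumes apex was built with its CUDA extensions so multi_tensor_applier is available, and uses only constructor arguments read in step() above such as lr, momentum, and weight_decay):

import torch
from apex.optimizers import FusedSGD

model = torch.nn.Linear(1024, 1024).cuda().half()
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

x = torch.randn(32, 1024, device='cuda', dtype=torch.float16)
loss = model(x).float().sum()
loss.backward()
optimizer.step()          # one fused multi-tensor launch per dtype bucket
optimizer.zero_grad()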
@@ -150,23 +150,23 @@ void multi_tensor_sgd_cuda(
bool wd_after_momentum)
{
auto num_tensors = tensor_lists.size();
-auto grad_type = tensor_lists[0][0].type().scalarType();
-auto weight_type = tensor_lists[0][0].type().scalarType();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
-// We have 4 potentials to handle here, in terms of
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
-// 2. fp16, fp32, fp32, No
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
-// 4. fp32, fp32, fp32, No
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No
-if (grad_type == at::ScalarType::Half &&
-weight_type == at::ScalarType::Half &&
-num_tensors == 3) {
if(grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Half &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
@@ -182,15 +182,34 @@ void multi_tensor_sgd_cuda(
wd_after_momentum);
}
// Case 2. fp16, fp32, fp32, No
-else if (grad_type == at::ScalarType::Half &&
-weight_type == at::ScalarType::Float &&
-num_tensors == 3) {
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<3>(
// BLOCK_SIZE,
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
else if(grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
-SGDFunctor<3, at::Half, float>(),
SGDFunctor<3, float, float>(),
wd,
momentum,
dampening,
@@ -200,9 +219,10 @@ void multi_tensor_sgd_cuda(
wd_after_momentum);
}
// Case 3. fp16, fp32, fp32, Yes
-else if (grad_type == at::ScalarType::Half &&
-weight_type == at::ScalarType::Float &&
-num_tensors == 4) {
else if(grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
@@ -217,25 +237,8 @@ void multi_tensor_sgd_cuda(
first_run,
wd_after_momentum);
}
-// Case 4. fp32, fp32, fp32, No
-else if (grad_type == at::ScalarType::Float &&
-weight_type == at::ScalarType::Float &&
-num_tensors == 3) {
-multi_tensor_apply<3>(
-BLOCK_SIZE,
-chunk_size,
-noop_flag,
-tensor_lists,
-SGDFunctor<3, float, float>(),
-wd,
-momentum,
-dampening,
-lr,
-nesterov,
-first_run,
-wd_after_momentum);
-}
-else {
else
{
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
}
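For reference, the three surviving dispatch cases correspond to the tensor_lists layouts assembled in the Python step() above. A schematic using standalone toy tensors (in practice these lists come from the parameter groups):

import torch

dev = 'cuda'
g16 = [torch.randn(8, device=dev, dtype=torch.float16)]  # fp16 grads
p16 = [torch.zeros(8, device=dev, dtype=torch.float16)]  # fp16 params (model copy)
m16 = [torch.zeros(8, device=dev, dtype=torch.float16)]  # fp16 momenta
g32 = [torch.randn(8, device=dev)]                        # fp32 grads
p32 = [torch.zeros(8, device=dev)]                        # fp32 (master) params
m32 = [torch.zeros(8, device=dev)]                        # fp32 momenta

case1 = [g16, p16, m16]        # fp16/fp16/fp16, num_tensors == 3
case2 = [g32, p32, m32]        # fp32/fp32/fp32, num_tensors == 3
case3 = [g16, p32, m32, p16]   # fp16 grads + fp32 masters; 4th list is the fp16 model copy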