Commit 75139ca3 authored by Michael Carilli

let's see

parent e0f2ffa5
@@ -11,6 +11,20 @@ class AmpOptimizerState(object):
     pass
 
+def _master_params_to_model_params(self):
+    stash = self._amp_stash
+    if multi_tensor_applier.available:
+        if len(stash.all_fp16_params) > 0:
+            multi_tensor_applier(
+                stash.multi_tensor_scale,
+                stash.dummy_overflow_buf,
+                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
+                1.0)
+    else:
+        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
+            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+
 def lazy_init_with_master_weights(self):
     stash = self._amp_stash
     stash.fp16_groups = []
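With a scale factor of 1.0, the multi_tensor_scale launch in the newly added _master_params_to_model_params is effectively a fused multi-tensor copy: each fp32 "from_fp16" master parameter is written back into its fp16 model counterpart, and the else branch falls back to the per-group Python helper. A minimal unfused sketch of that copy, using hypothetical stand-in lists rather than the real stash attributes:

import torch

def copy_master_to_model(master_params, model_params):
    # Unfused equivalent of multi_tensor_scale(..., 1.0): cast each fp32
    # master parameter back into its fp16 model parameter, in place.
    with torch.no_grad():
        for master, model in zip(master_params, model_params):
            model.copy_(master)

# Hypothetical stand-ins for stash.all_fp32_from_fp16_params / stash.all_fp16_params:
masters = [torch.randn(4, dtype=torch.float32)]
models = [m.half() for m in masters]
copy_master_to_model(masters, models)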
@@ -277,6 +291,8 @@ def post_backward_no_master_weights_FusedAdam(self, scaler):
 # implementations until I have them both working.
 #####################################################################################
 
+# FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
+# outside the kernel, so we must accumulate directly into the model grads.
 def prepare_backward_with_master_weights_FusedSGD(self):
     stash = self._amp_stash
@@ -284,60 +300,33 @@ def prepare_backward_with_master_weights_FusedSGD(self):
         self._lazy_init_maybe_master_weights()
         stash.lazy_init_called = True
 
+    for i, param in enumerate(stash.all_fp16_params):
+        stash.all_fp16_grad_stash[i] = param.grad
+        # Set up to leverage grad copy elision:
+        param.grad = None
+
+    for i, param in enumerate(stash.all_fp32_from_fp32_params):
+        stash.all_fp32_from_fp32_grad_stash[i] = param.grad
+        # Set up to leverage grad copy elision:
+        param.grad = None
 
 def post_backward_with_master_weights_FusedSGD(self, scaler):
     stash = self._amp_stash
-    stash.scale = scaler.loss_scale()
-    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
-    stash.output_params = [[param for param in group] for group in stash.fp16_groups]
-    norm_groups = []
-    skip = False
-    for grad_group in stash.grads:
-        norm = multi_tensor_applier(
-            stash.multi_tensor_l2norm,
-            stash.dummy_overflow_buf,
-            [grad_group])
-        # Still syncing here for now.
-        norm = float(norm)
-        norm_groups.append(norm)
-        if norm == float('inf') or norm == -float('inf') or norm != norm:
-            skip = True
-    if skip:
-        scaler._overflow_buf.fill_(1.)
-        scaler._has_overflow = True
-    stash.grad_norms = norm_groups
+    split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
+                   (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
+    for params, stashed_grads in split_types:
+        post_backward_models_are_masters(scaler, params, stashed_grads)
 
 def prepare_backward_no_master_weights_FusedSGD(self):
-    stash = self._amp_stash
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    prepare_backward_no_master_weights(self)
 
 def post_backward_no_master_weights_FusedSGD(self, scaler):
-    stash = self._amp_stash
-    stash.scale = scaler.loss_scale()
-    stash.grads = None
-    stash.output_params = None
-    stash.grad_norms = None
-
-def _master_params_to_model_params(self):
-    stash = self._amp_stash
-    if multi_tensor_applier.available:
-        if len(stash.all_fp16_params) > 0:
-            multi_tensor_applier(
-                stash.multi_tensor_scale,
-                stash.dummy_overflow_buf,
-                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
-                1.0)
-    else:
-        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
-            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+    post_backward_no_master_weights(self, scaler)
 
 def _process_optimizer(optimizer, properties):
...
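The rewritten FusedSGD hooks above delegate to helpers defined elsewhere in this file (prepare_backward_no_master_weights, post_backward_no_master_weights, post_backward_models_are_masters), whose bodies are outside this diff. Consistent with the comment added above, the model fp16 grads themselves are what the fused optimizer consumes, so the post-backward step only needs to unscale the freshly produced grads and fold back any grads stashed before backward. A hedged sketch of that pattern with a hypothetical helper name, not apex's actual implementation:

import torch

def unscale_and_accumulate(params, stashed_grads, loss_scale):
    # Grads produced by this backward pass are scaled by loss_scale; grads
    # stashed before backward (when param.grad was set to None for copy
    # elision) are already unscaled. Combine them so param.grad ends up
    # unscaled and accumulated.
    inv_scale = 1.0 / loss_scale
    for param, stashed in zip(params, stashed_grads):
        if param.grad is not None:
            param.grad.mul_(inv_scale)
            if stashed is not None:
                param.grad.add_(stashed)
        elif stashed is not None:
            param.grad = stashed

# Hypothetical usage with a single fp32 parameter:
p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.full((3,), 128.0)  # grad from a loss that was scaled by 128
unscale_and_accumulate([p], [torch.ones(3)], loss_scale=128.0)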
@@ -80,6 +80,22 @@ class FusedSGD(Optimizer):
         for group in self.param_groups:
             group.setdefault('nesterov', False)
 
+    def get_momentums(self, params):
+        momentums = []
+        first_run = True
+        for p in params:
+            param_state = self.state[p]
+            # torch.optim.SGD initializes momentum in the main loop, we have
+            # to do it here, and track whether or not we've done so, so that
+            # momentum application can be skipped in the main kernel.
+            if 'momentum_buffer' not in param_state:
+                first_run = True
+                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
+                momentums.append(buf)
+            else:
+                first_run = False
+                momentums.append(param_state['momentum_buffer'])
+        return momentums, first_run
+
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -91,73 +107,59 @@ class FusedSGD(Optimizer):
         if closure is not None:
             loss = closure()
 
-        for group in self.param_groups:
+        explicit_master_params = (hasattr(self, "_amp_stash") and
+                                  hasattr(self._amp_stash, "fp32_from_fp16_groups"))
+
+        for gid, group in enumerate(self.param_groups):
             weight_decay = group['weight_decay']
             momentum = group['momentum']
             dampening = group['dampening']
             nesterov = group['nesterov']
 
-            params = [p for p in group['params'] if p is not None]
-            grads = [p.grad for p in params]
-
-            momentums = []
-            for p in params:
-                param_state = self.state[p]
-                # torch.optim.SGD initializes momentum in the main loop, we have
-                # to do it here, and track whether or not we've done so, so that
-                # momentum application can be skipped in the main kernel.
-                if 'momentum_buffer' not in param_state:
-                    first_run = True
-                    buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
-                    momentums.append(buf)
-                else:
-                    first_run = False
-                    momentums.append(param_state['momentum_buffer'])
-
-            # We have all parameters now, split them into appropriate groups for
-            # parallel execution, following the 4 possible combos that the underlying
-            # kernels support:
-            # grad_type, param_type, momentum_type, requires_fp16_copy
-            # 1. fp16, fp16, fp16, No
-            # 2. fp16, fp32, fp32, No
-            # 3. fp16, fp32, fp32, Yes
-            # 4. fp32, fp32, fp32, No
-            # As in the kernel, easier to hardcode these options
-            # Store only indices into the weight / grad / momentum lists
-            # { gradient-type : { param-type : List } | List }
-            param_sets = { 'fp16' : { 'fp16' : [], 'fp32' : [] }, 'fp32' : [] }
-            for i, (g, p) in enumerate(zip(grads, params)):
-                if g.dtype == torch.float16:
-                    # fp16 grads, fp16 params
-                    if p.dtype == torch.float16:
-                        param_sets['fp16']['fp16'].append(i)
-                    # fp16 grads, fp32 params
-                    elif p.dtype == torch.float32:
-                        param_sets['fp16']['fp32'].append(i)
-                    else:
-                        raise RuntimeError('fp16 gradients need either fp16 or fp32 weights')
-                # fp32 grads, fp32 params
-                elif g.dtype == torch.float32:
-                    param_sets['fp32'].append(i)
-                else:
-                    raise RuntimeError('gradients must either be fp16 or fp32')
-
-            def launch_sgd_set(param_set):
-                local_params, local_grads, local_momentums = [], [], []
-                if len(param_set) == 0:
-                    return
-                # launch update using multi tensor applier
-                # modifies weight and momentum values inplace.
+            # For each group, there are 3 possible combinations we need to consider:
+            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
+            # 1. fp16, fp16, fp16, No
+            # 2. fp32, fp32, fp32, No
+            # 3. fp16, fp32, fp32, Yes
+
+            first_runs = [True, True]
+
+            # I think a bit of code divergence in exchange for naming clarity is worthwhile
+            if explicit_master_params:
+                stash = self._amp_stash
+
+                fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
+                fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
+                fp32_from_fp16_params = [p for i, p in enumerate(
+                    stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
+                fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
+
+                fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
+                fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
+                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
+
+                launch_sets = [[fp16_model_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params],
+                               [fp32_grads, fp32_params, fp32_momentums]]
+            else:
+                fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
+                fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
+                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
+
+                fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
+                fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
+                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
+
+                launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
+                               [fp32_grads, fp32_params, fp32_momentums]]
+
+            for launch_set, first_run in zip(launch_sets, first_runs):
                 multi_tensor_applier(
                     self.multi_tensor_sgd,
                     self._dummy_overflow_buf,
                     # Note: Need to do this as list comprehensions otherwise
                     # things don't seem to update properly.
-                    [[grads[i] for i in param_set],
-                     [params[i] for i in param_set],
-                     [momentums[i] for i in param_set]],
+                    launch_set,
                     weight_decay,
                     momentum,
                     dampening,
@@ -166,9 +168,4 @@ class FusedSGD(Optimizer):
                     first_run,
                     self.wd_after_momentum)
 
-            # Explicitly go over the cases
-            launch_sgd_set(param_sets['fp16']['fp16'])
-            launch_sgd_set(param_sets['fp16']['fp32'])
-            launch_sgd_set(param_sets['fp32'])
-
         return loss
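When amp maintains explicit fp32 master weights, the step() above launches the fp16 group with four tensor lists (fp16 model grads, fp32 master params, fp32 momenta, and the fp16 model params to be refreshed), which corresponds to Case 3 in the CUDA dispatch below; every other group uses three lists. A hedged end-to-end usage sketch, assuming apex built with its CUDA extensions and a GPU available; illustrative only, and the amp integration may differ from this in-progress commit:

import torch
from apex import amp
from apex.optimizers import FusedSGD

model = torch.nn.Linear(64, 64).cuda()
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9)

# O2 keeps fp32 master weights, so step() takes the explicit_master_params
# branch and uses the 4-list (Case 3) launch for the fp16 parameters.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(8, 64, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()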
@@ -150,23 +150,23 @@ void multi_tensor_sgd_cuda(
   bool wd_after_momentum)
 {
   auto num_tensors = tensor_lists.size();
-  auto grad_type = tensor_lists[0][0].type().scalarType();
-  auto weight_type = tensor_lists[0][0].type().scalarType();
+  auto grad_type = tensor_lists[0][0].scalar_type();
+  auto weight_type = tensor_lists[1][0].scalar_type();
 
-  // We have 4 potentials to handle here, in terms of
+  // We have 3 possibilities to handle here, in terms of
   // grad_type, param_type, momentum_type, requires_fp16_copy
   // 1. fp16, fp16, fp16, No
-  // 2. fp16, fp32, fp32, No
+  // 2. fp32, fp32, fp32, No
   // 3. fp16, fp32, fp32, Yes
-  // 4. fp32, fp32, fp32, No
   // It's easier to hardcode these possibilities than to use
   // switches etc. to handle the cross-product of cases where
   // we don't want the majority of them.
 
   // Case 1. fp16, fp16, fp16, No
-  if (grad_type == at::ScalarType::Half &&
+  if(grad_type == at::ScalarType::Half &&
     weight_type == at::ScalarType::Half &&
-    num_tensors == 3) {
+    num_tensors == 3)
+  {
     multi_tensor_apply<3>(
       BLOCK_SIZE,
       chunk_size,
@@ -182,15 +182,34 @@ void multi_tensor_sgd_cuda(
       wd_after_momentum);
   }
   // Case 2. fp16, fp32, fp32, No
-  else if (grad_type == at::ScalarType::Half &&
+  // else if (grad_type == at::ScalarType::Half &&
+  //          weight_type == at::ScalarType::Float &&
+  //          num_tensors == 3) {
+  //   multi_tensor_apply<3>(
+  //     BLOCK_SIZE,
+  //     chunk_size,
+  //     noop_flag,
+  //     tensor_lists,
+  //     SGDFunctor<3, at::Half, float>(),
+  //     wd,
+  //     momentum,
+  //     dampening,
+  //     lr,
+  //     nesterov,
+  //     first_run,
+  //     wd_after_momentum);
+  // }
+  // Case 2. fp32, fp32, fp32, No
+  else if(grad_type == at::ScalarType::Float &&
     weight_type == at::ScalarType::Float &&
-    num_tensors == 3) {
+    num_tensors == 3)
+  {
     multi_tensor_apply<3>(
       BLOCK_SIZE,
       chunk_size,
       noop_flag,
       tensor_lists,
-      SGDFunctor<3, at::Half, float>(),
+      SGDFunctor<3, float, float>(),
       wd,
       momentum,
       dampening,
@@ -200,9 +219,10 @@ void multi_tensor_sgd_cuda(
       wd_after_momentum);
   }
   // Case 3. fp16, fp32, fp32, Yes
-  else if (grad_type == at::ScalarType::Half &&
+  else if(grad_type == at::ScalarType::Half &&
     weight_type == at::ScalarType::Float &&
-    num_tensors == 4) {
+    num_tensors == 4)
+  {
     multi_tensor_apply<4>(
       BLOCK_SIZE,
       chunk_size,
@@ -217,25 +237,8 @@ void multi_tensor_sgd_cuda(
       first_run,
       wd_after_momentum);
   }
-  // Case 4. fp32, fp32, fp32, No
-  else if (grad_type == at::ScalarType::Float &&
-    weight_type == at::ScalarType::Float &&
-    num_tensors == 3) {
-    multi_tensor_apply<3>(
-      BLOCK_SIZE,
-      chunk_size,
-      noop_flag,
-      tensor_lists,
-      SGDFunctor<3, float, float>(),
-      wd,
-      momentum,
-      dampening,
-      lr,
-      nesterov,
-      first_run,
-      wd_after_momentum);
-  }
-  else {
+  else
+  {
     AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
              "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
   }
...
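After this change, the dispatch in multi_tensor_sgd_cuda keys purely off the gradient dtype, the dtype of the parameters being updated, and the number of tensor lists. An illustrative Python rendering of that mapping (a summary of the if/else chain above, not part of apex):

import torch

def pick_sgd_case(grad_dtype, weight_dtype, num_lists):
    # Mirrors the if/else chain in multi_tensor_sgd_cuda after this commit.
    if grad_dtype == torch.float16 and weight_dtype == torch.float16 and num_lists == 3:
        return "Case 1: fp16 grads, fp16 params, fp16 momentum"
    if grad_dtype == torch.float32 and weight_dtype == torch.float32 and num_lists == 3:
        return "Case 2: fp32 grads, fp32 params, fp32 momentum"
    if grad_dtype == torch.float16 and weight_dtype == torch.float32 and num_lists == 4:
        return "Case 3: fp16 grads, fp32 master params and momentum, fp16 model copy"
    raise RuntimeError("unsupported combination of gradient and weight types")

print(pick_sgd_case(torch.float16, torch.float32, 4))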