Commit b265b0b5 authored by Simon Layton's avatar Simon Layton
Browse files

Fix dispatch, add wd after momentum option

Fix dispatch when a parameter group contains multiple
combinations of gradient/parameter types.
Optionally apply weight decay after the momentum update instead of before it.
parent ac74f345
...@@ -50,7 +50,8 @@ class SGD(Optimizer): ...@@ -50,7 +50,8 @@ class SGD(Optimizer):
""" """
def __init__(self, params, lr=required, momentum=0, dampening=0, def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False): weight_decay=0, nesterov=False,
wd_after_momentum=False):
if lr is not required and lr < 0.0: if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0: if momentum < 0.0:
...@@ -64,6 +65,8 @@ class SGD(Optimizer): ...@@ -64,6 +65,8 @@ class SGD(Optimizer):
raise ValueError("Nesterov momentum requires a momentum and zero dampening") raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SGD, self).__init__(params, defaults) super(SGD, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum
if multi_tensor_applier.available: if multi_tensor_applier.available:
import amp_C import amp_C
# Skip buffer # Skip buffer
...@@ -111,17 +114,62 @@ class SGD(Optimizer): ...@@ -111,17 +114,62 @@ class SGD(Optimizer):
first_run = False first_run = False
momentums.append(param_state['momentum_buffer']) momentums.append(param_state['momentum_buffer'])
# launch update using multi tensor applier # We have all parameters now, split them into appropriate groups for
# modifies weight and momentum values inplace. # parallel execution, following the 4 possible combos that the underlying
multi_tensor_applier( # kernels support:
self.multi_tensor_sgd, # grad_type, param_type, momentum_type, requires_fp16_copy
self._dummy_overflow_buf, # 1. fp16, fp16, fp16, No
[grads, params, momentums], # 2. fp16, fp32, fp32, No
weight_decay, # 3. fp16, fp32, fp32, Yes
momentum, # 4. fp32, fp32, fp32, No
dampening, # As in the kernel, easier to hardcode these options
group['lr'],
nesterov, # Store only indices into the weight / grad / momentum lists
first_run) # { gradient-type : { param-type : List } | List }
param_sets = { 'fp16' : { 'fp16' : [], 'fp32' : [] }, 'fp32' : [] }
for i, (g, p) in enumerate(zip(grads, params)):
if g.dtype == torch.float16:
# fp16 grads, fp16 params
if p.dtype == torch.float16:
param_sets['fp16']['fp16'].append(i)
# fp16 grads, fp32 params
elif p.dtype == torch.float32:
param_sets['fp16']['fp32'].append(i)
else:
raise RuntimeError('fp16 gradients need either fp16 or fp32 weights')
# fp32 grads, fp32 params
elif g.dtype == torch.float32:
param_sets['fp32'].append(i)
else:
raise RuntimeError('gradients must either be fp16 or fp32')
def launch_sgd_set(param_set):
local_params, local_grads, local_momentums = [], [], []
if len(param_set) == 0:
return
# launch update using multi tensor applier
# modifies weight and momentum values inplace.
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
# Note: Need to do this as list comprehensions otherwise
# things don't seem to update properly.
[[grads[i] for i in param_set],
[params[i] for i in param_set],
[momentums[i] for i in param_set]],
weight_decay,
momentum,
dampening,
group['lr'],
nesterov,
first_run,
self.wd_after_momentum)
# Explicitly go over the cases
launch_sgd_set(param_sets['fp16']['fp16'])
launch_sgd_set(param_sets['fp16']['fp32'])
launch_sgd_set(param_sets['fp32'])
return loss return loss
...@@ -15,7 +15,8 @@ void multi_tensor_sgd_cuda( ...@@ -15,7 +15,8 @@ void multi_tensor_sgd_cuda(
float dampening, float dampening,
float lr, float lr,
bool nesterov, bool nesterov,
bool first_run); bool first_run,
bool wd_after_momentum);
void scale_check_overflow_cuda( void scale_check_overflow_cuda(
const at::Tensor& grads, const at::Tensor& grads,
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
* lr : learning rate (scalar) * lr : learning rate (scalar)
* nesterov : enable nesterov (bool) * nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init * first run : necessary for proper momentum handling & init
* wd_after_momentum : apply weight decay _after_ momentum instead of before
**/ **/
template<int N, typename T_grad, typename T_weight> template<int N, typename T_grad, typename T_weight>
struct SGDFunctor struct SGDFunctor
...@@ -36,7 +37,8 @@ struct SGDFunctor ...@@ -36,7 +37,8 @@ struct SGDFunctor
float dampening, float dampening,
float lr, float lr,
bool nesterov, bool nesterov,
bool first_run) bool first_run,
bool wd_after_momentum)
{ {
// Early exit if we don't need to do anything // Early exit if we don't need to do anything
if (*noop_gmem) return; if (*noop_gmem) return;
...@@ -93,8 +95,8 @@ struct SGDFunctor ...@@ -93,8 +95,8 @@ struct SGDFunctor
{ {
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) { if(i < n && i < chunk_size) {
// apply weight decay // apply weight decay before momentum if necessary
if (wd != 0.f) { if (wd != 0.f && !wd_after_momentum) {
incoming_grads[ii] += wd * incoming_weights[ii]; incoming_grads[ii] += wd * incoming_weights[ii];
} }
if (momentum != 0.f) { if (momentum != 0.f) {
...@@ -109,6 +111,11 @@ struct SGDFunctor ...@@ -109,6 +111,11 @@ struct SGDFunctor
} }
} }
// Apply WD after momentum if desired
if (wd != 0.f && wd_after_momentum) {
incoming_grads[ii] += wd * incoming_weights[ii];
}
// adjust the weight and write out // adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]); weight_in[i] += (-lr * incoming_grads[ii]);
...@@ -136,7 +143,8 @@ void multi_tensor_sgd_cuda( ...@@ -136,7 +143,8 @@ void multi_tensor_sgd_cuda(
float dampening, float dampening,
float lr, float lr,
bool nesterov, bool nesterov,
bool first_run) bool first_run,
bool wd_after_momentum)
{ {
auto num_tensors = tensor_lists.size(); auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].type().scalarType(); auto grad_type = tensor_lists[0][0].type().scalarType();
...@@ -167,7 +175,8 @@ void multi_tensor_sgd_cuda( ...@@ -167,7 +175,8 @@ void multi_tensor_sgd_cuda(
dampening, dampening,
lr, lr,
nesterov, nesterov,
first_run); first_run,
wd_after_momentum);
} }
// Case 2. fp16, fp32, fp32, No // Case 2. fp16, fp32, fp32, No
else if (grad_type == at::ScalarType::Half && else if (grad_type == at::ScalarType::Half &&
...@@ -184,7 +193,8 @@ void multi_tensor_sgd_cuda( ...@@ -184,7 +193,8 @@ void multi_tensor_sgd_cuda(
dampening, dampening,
lr, lr,
nesterov, nesterov,
first_run); first_run,
wd_after_momentum);
} }
// Case 3. fp16, fp32, fp32, Yes // Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half && else if (grad_type == at::ScalarType::Half &&
...@@ -201,7 +211,8 @@ void multi_tensor_sgd_cuda( ...@@ -201,7 +211,8 @@ void multi_tensor_sgd_cuda(
dampening, dampening,
lr, lr,
nesterov, nesterov,
first_run); first_run,
wd_after_momentum);
} }
// Case 4. fp32, fp32, fp32, No // Case 4. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float && else if (grad_type == at::ScalarType::Float &&
...@@ -218,7 +229,8 @@ void multi_tensor_sgd_cuda( ...@@ -218,7 +229,8 @@ void multi_tensor_sgd_cuda(
dampening, dampening,
lr, lr,
nesterov, nesterov,
first_run); first_run,
wd_after_momentum);
} }
else { else {
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ", AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment