Commit cac061a1 authored by Simon Layton

Code cleanup, add fused fp16 read / write

Fuse the fp16 gradient -> fp32 conversion into the kernel
Add an option to write out an fp16 copy of the weights
parent cadad920
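
For orientation, the update this commit fuses is standard SGD with momentum, dampening, Nesterov, and weight decay, with the fp16 gradient read and the optional fp16 weight write folded into the same kernel pass. Below is a minimal plain-PyTorch sketch of that per-parameter update, assuming the kernel is meant to reproduce torch.optim.SGD semantics; the helper name sgd_step_reference and the fp16_copy argument are illustrative, not part of this commit.

import torch

def sgd_step_reference(param, grad, momentum_buf, lr, weight_decay,
                       momentum, dampening, nesterov, first_run,
                       fp16_copy=None):
    # fp16 gradients are read and converted to fp32 inside the fused kernel;
    # here the cast is explicit.
    d_p = grad.float()
    if weight_decay != 0:
        d_p = d_p + weight_decay * param
    if momentum != 0:
        if first_run:
            momentum_buf.copy_(d_p)          # buffer initialized on the first step
        else:
            momentum_buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
        d_p = d_p + momentum * momentum_buf if nesterov else momentum_buf
    param.add_(d_p, alpha=-lr)               # weight update, in place
    if fp16_copy is not None:                # optional fp16 write-out (N == 4 path below)
        fp16_copy.copy_(param)
    return param
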
@@ -3,8 +3,6 @@ from torch.optim.optimizer import Optimizer, required
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).
@@ -66,8 +64,13 @@ class SGD(Optimizer):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_sgd = amp_C.multi_tensor_sgd
        else:
            raise RuntimeError('apex.optim.SGD requires cuda extensions')

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
@@ -96,6 +99,9 @@ class SGD(Optimizer):
            momentums = []
            for p in params:
                param_state = self.state[p]
                # torch.optim.SGD initializes momentum in its main loop; we have
                # to do it here and track whether we've done so, so that momentum
                # application can be skipped in the main kernel.
                if 'momentum_buffer' not in param_state:
                    first_run = True
                    buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
@@ -105,9 +111,10 @@ class SGD(Optimizer):
                    first_run = False
                momentums.append(param_state['momentum_buffer'])

            # launch update using multi tensor apply
            # launch update using multi tensor applier
            # modifies weight and momentum values in place.
            multi_tensor_applier(
                amp_C.multi_tensor_sgd,
                self.multi_tensor_sgd,
                self._dummy_overflow_buf,
                [grads, params, momentums],
                weight_decay,
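The Python hunks above still pass three tensor lists; the kernel changes below add support for an optional fourth list holding fp16 copies of the weights (tl[3]). A hedged sketch of how a caller might opt into that path, assuming apex's multi_tensor_applier and the amp_C extension are built; the helper name fused_sgd_apply and the fp16_params argument are hypothetical, not from this diff.

import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

def fused_sgd_apply(grads, params, momentums, fp16_params=None,
                    weight_decay=0.0, momentum=0.9, dampening=0.0,
                    lr=0.1, nesterov=False, first_run=True):
    overflow_buf = torch.cuda.IntTensor([0])   # dummy "skip" buffer
    tensor_lists = [grads, params, momentums]  # tl[0], tl[1], tl[2]
    if fp16_params is not None:
        tensor_lists.append(fp16_params)       # tl[3]: selects the N == 4 kernel path
    multi_tensor_applier(amp_C.multi_tensor_sgd, overflow_buf, tensor_lists,
                         weight_decay, momentum, dampening, lr, nesterov, first_run)
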
@@ -12,9 +12,11 @@
/**
 * Perform fused SGD on multiple buffers
 * N: number of tensors
 * tl[0] : gradients
 * tl[1] : weights
 * tl[2] : momentum buffers
 * tl[3] : fp16 weights (if appropriate)
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar)
 * dampening : momentum dampening (scalar)
@@ -22,13 +24,13 @@
 * nesterov : enable nesterov (bool)
 * first_run : necessary for proper momentum handling & init
 **/
template<typename T>
template<int N, typename T_grad, typename T>
struct SGDFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorList<3>& tl,
    TensorList<N>& tl,
    float wd,
    float momentum,
    float dampening,
@@ -48,7 +50,7 @@ struct SGDFunctor
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T* grad_in = (T*)tl.addresses[0][tensor_loc];
    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;
    T* weight_in = (T*)tl.addresses[1][tensor_loc];
@@ -57,12 +59,18 @@ struct SGDFunctor
    T* mom_in = (T*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;

    half *model_weights_out = nullptr;
    if (N == 4) {
      model_weights_out = (half*)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx*chunk_size;
    }

    n -= chunk_idx*chunk_size;

    // Non-divergent exit condition for the __syncthreads
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    T incoming_grads[ILP];
    T incoming_weights[ILP];
    T incoming_moms[ILP];

    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
@@ -75,9 +83,9 @@ struct SGDFunctor
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i]);
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
          incoming_grads[ii] = static_cast<T>(grad_in[i]);
          incoming_weights[ii] = static_cast<T>(weight_in[i]);
          incoming_moms[ii] = static_cast<T>(mom_in[i]);
        }
      }

      // note for clarification to future michael:
@@ -107,6 +115,11 @@ struct SGDFunctor
        // adjust the weight and write out
        weight_in[i] += (-lr * incoming_grads[ii]);

        // if necessary, write out an fp16 copy of the weights
        if (N == 4) {
          model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
        }

        // also write out the new momentum
        if (momentum != 0.f) {
          mom_in[i] = incoming_moms[ii];
@@ -137,20 +150,79 @@ void multi_tensor_sgd_cuda(
  bool nesterov,
  bool first_run)
{
  auto num_tensors = tensor_lists.size();

  switch (num_tensors) {
    case 3:
      switch (tensor_lists[0][0].type().scalarType()) {
        case at::ScalarType::Half:
          multi_tensor_apply<3>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<float>(),
            SGDFunctor<3, at::Half, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run);
          break;
        case at::ScalarType::Float:
          multi_tensor_apply<3>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<3, float, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run);
          break;
        default:
          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ", tensor_lists[0][0].type().scalarType());
      }
      break;
    case 4:
      switch (tensor_lists[0][0].type().scalarType()) {
        case at::ScalarType::Half:
          multi_tensor_apply<4>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<4, at::Half, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run);
          break;
        case at::ScalarType::Float:
          multi_tensor_apply<4>(
            BLOCK_SIZE,
            chunk_size,
            noop_flag,
            tensor_lists,
            SGDFunctor<4, float, float>(),
            wd,
            momentum,
            dampening,
            lr,
            nesterov,
            first_run);
          break;
        default:
          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ", tensor_lists[0][0].type().scalarType());
      }
      break;  // avoid falling through into the default error case
    default:
      AT_ERROR("multi_tensor_sgd takes either 3 or 4 sets of tensors, given ", num_tensors);
  }

  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}
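
As a sanity check on the dispatch above, one could compare one fused step against torch.optim.SGD on copies of the same data. This is a hedged sketch, assuming apex is built with its CUDA extensions; the test is not part of this commit. An fp32 three-list call like this one exercises the SGDFunctor<3, float, float> path.

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

torch.manual_seed(0)
params     = [torch.randn(1024, device='cuda') for _ in range(4)]
grads      = [torch.randn_like(p) for p in params]
momentums  = [torch.zeros_like(p) for p in params]
ref_params = [p.clone().requires_grad_(True) for p in params]  # reference copies

wd, mom, damp, lr, nesterov = 0.0, 0.9, 0.0, 0.1, False

# Fused path: three fp32 tensor lists, first_run=True.
overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_sgd, overflow_buf,
                     [grads, params, momentums],
                     wd, mom, damp, lr, nesterov, True)

# Reference path: one torch.optim.SGD step with the same gradients.
opt = torch.optim.SGD(ref_params, lr=lr, momentum=mom, dampening=damp,
                      weight_decay=wd, nesterov=nesterov)
for p, g in zip(ref_params, grads):
    p.grad = g.clone()
opt.step()

for fused, ref in zip(params, ref_params):
    print(torch.allclose(fused, ref.detach(), atol=1e-6))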