Commit cac061a1 authored by Simon Layton

Code cleanup, add fused fp16 read / write

Fuse the fp16 gradient -> fp32 conversion into the update kernel.
Optionally write out an fp16 copy of the updated weights.
parent cadad920
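
In reference terms, the fused path reads each fp16 gradient, converts it to fp32 once as part of the load, performs the SGD update on fp32 master weights, and can write an fp16 copy of the updated weights in the same pass. Below is a minimal PyTorch sketch of that per-parameter data flow, assuming standard torch.optim.SGD update semantics; tensor and argument names are illustrative, not apex's API.

import torch

def fused_sgd_reference(grad_fp16, master_w, momentum_buf, lr, wd,
                        momentum, dampening, nesterov,
                        model_w_fp16=None, first_run=False):
    # fp16 -> fp32 conversion is fused into the gradient read
    g = grad_fp16.float()
    if wd != 0.0:
        g = g.add(master_w, alpha=wd)
    if momentum != 0.0:
        if first_run:
            momentum_buf.copy_(g)  # fresh buffer: momentum application is skipped
        else:
            momentum_buf.mul_(momentum).add_(g, alpha=1.0 - dampening)
        if nesterov:
            g = g.add(momentum_buf, alpha=momentum)
        else:
            g = momentum_buf
    master_w.add_(g, alpha=-lr)
    if model_w_fp16 is not None:
        # the optional fp16 weight copy written out in the same pass
        model_w_fp16.copy_(master_w)  # implicit fp32 -> fp16 cast
    return master_w

Because the gradient type and the fp16 write-out are compile-time template parameters of the CUDA functor below, neither the conversion nor the copy costs an extra kernel launch or memory pass.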
@@ -3,8 +3,6 @@ from torch.optim.optimizer import Optimizer, required
 from apex.multi_tensor_apply import multi_tensor_applier
-import amp_C
-
 class SGD(Optimizer):
     r"""Implements stochastic gradient descent (optionally with momentum).
@@ -66,8 +64,13 @@ class SGD(Optimizer):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super(SGD, self).__init__(params, defaults)
-        # Skip buffer
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        if multi_tensor_applier.available:
+            import amp_C
+            # Skip buffer
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self.multi_tensor_sgd = amp_C.multi_tensor_sgd
+        else:
+            raise RuntimeError('apex.optim.SGD requires cuda extensions')

     def __setstate__(self, state):
         super(SGD, self).__setstate__(state)
@@ -96,6 +99,9 @@ class SGD(Optimizer):
         momentums = []
         for p in params:
             param_state = self.state[p]
+            # torch.optim.SGD initializes momentum in the main loop, we have
+            # to do it here, and track whether or not we've done so, so that
+            # momentum application can be skipped in the main kernel.
             if 'momentum_buffer' not in param_state:
                 first_run = True
                 buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
@@ -105,9 +111,10 @@ class SGD(Optimizer):
                 first_run = False
                 momentums.append(param_state['momentum_buffer'])

-        # launch update using multi tensor apply
+        # launch update using multi tensor applier
+        # modifies weight and momentum values inplace.
         multi_tensor_applier(
-            amp_C.multi_tensor_sgd,
+            self.multi_tensor_sgd,
             self._dummy_overflow_buf,
             [grads, params, momentums],
             weight_decay,
...
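
Note that the Python side of this commit still passes exactly three tensor lists, so the new fp16 write-out below is reachable only from the C++ entry point. A hypothetical caller that wanted the fused fp16 weight copy might extend the call like this; the fp16_params list is illustrative and not part of this commit.

# hypothetical: fp16_params holds the half-precision model copies to refresh
multi_tensor_applier(
    self.multi_tensor_sgd,
    self._dummy_overflow_buf,
    [grads, params, momentums, fp16_params],
    weight_decay,
    momentum,
    dampening,
    lr,
    nesterov,
    first_run)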
@@ -12,9 +12,11 @@
 /**
  * Perform fused SGD on multiple buffers
+ * N: number of tensors
  * tl[0] : gradients
  * tl[1] : weights
  * tl[2] : momentum buffers
+ * tl[3] : fp16 weights (if appropriate)
  * wd : weight_decay (scalar)
  * momentum : momentum (scalar)
  * dampening : momentum dampening (scalar)
@@ -22,13 +24,13 @@
  * nesterov : enable nesterov (bool)
  * first run : necessary for proper momentum handling & init
  **/
-template<typename T>
+template<int N, typename T_grad, typename T>
 struct SGDFunctor
 {
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorList<3>& tl,
+    TensorList<N>& tl,
     float wd,
     float momentum,
     float dampening,
@@ -48,7 +50,7 @@ struct SGDFunctor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];

-    T* grad_in = (T*)tl.addresses[0][tensor_loc];
+    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
     grad_in += chunk_idx*chunk_size;

     T* weight_in = (T*)tl.addresses[1][tensor_loc];
@@ -57,12 +59,18 @@ struct SGDFunctor
     T* mom_in = (T*)tl.addresses[2][tensor_loc];
     mom_in += chunk_idx*chunk_size;

+    half *model_weights_out = nullptr;
+    if (N == 4) {
+      model_weights_out = (half*)tl.addresses[3][tensor_loc];
+      model_weights_out += chunk_idx*chunk_size;
+    }
+
     n -= chunk_idx*chunk_size;

     // Non-divergent exit condition for the __syncthreads
-    float incoming_grads[ILP];
-    float incoming_weights[ILP];
-    float incoming_moms[ILP];
+    T incoming_grads[ILP];
+    T incoming_weights[ILP];
+    T incoming_moms[ILP];
     for(int i_start = 0;
         i_start < n && i_start < chunk_size;
         i_start += blockDim.x*ILP)
@@ -75,9 +83,9 @@ struct SGDFunctor
         incoming_moms[ii] = 0;
         int i = i_start + threadIdx.x + ii*blockDim.x;
         if(i < n && i < chunk_size)
         {
-          incoming_grads[ii] = static_cast<float>(grad_in[i]);
-          incoming_weights[ii] = static_cast<float>(weight_in[i]);
-          incoming_moms[ii] = static_cast<float>(mom_in[i]);
+          incoming_grads[ii] = static_cast<T>(grad_in[i]);
+          incoming_weights[ii] = static_cast<T>(weight_in[i]);
+          incoming_moms[ii] = static_cast<T>(mom_in[i]);
         }
       }

       // note for clarification to future michael:
@@ -107,6 +115,11 @@ struct SGDFunctor
           // adjust the weight and write out
           weight_in[i] += (-lr * incoming_grads[ii]);

+          // if necessary, write out an fp16 copy of the weights
+          if (N == 4) {
+            model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
+          }
+
           // also write out the new momentum
           if (momentum != 0.f) {
             mom_in[i] = incoming_moms[ii];
@@ -137,20 +150,79 @@ void multi_tensor_sgd_cuda(
   bool nesterov,
   bool first_run)
 {
-  multi_tensor_apply<3>(
-    BLOCK_SIZE,
-    chunk_size,
-    noop_flag,
-    tensor_lists,
-    SGDFunctor<float>(),
-    wd,
-    momentum,
-    dampening,
-    lr,
-    nesterov,
-    first_run);
+  auto num_tensors = tensor_lists.size();
+
+  switch (num_tensors) {
+    case 3:
+      switch (tensor_lists[0][0].type().scalarType()) {
+        case at::ScalarType::Half:
+          multi_tensor_apply<3>(
+            BLOCK_SIZE,
+            chunk_size,
+            noop_flag,
+            tensor_lists,
+            SGDFunctor<3, at::Half, float>(),
+            wd,
+            momentum,
+            dampening,
+            lr,
+            nesterov,
+            first_run);
+          break;
+        case at::ScalarType::Float:
+          multi_tensor_apply<3>(
+            BLOCK_SIZE,
+            chunk_size,
+            noop_flag,
+            tensor_lists,
+            SGDFunctor<3, float, float>(),
+            wd,
+            momentum,
+            dampening,
+            lr,
+            nesterov,
+            first_run);
+          break;
+        default:
+          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ",
+                   tensor_lists[0][0].type().scalarType());
+      }
+      break;
+    case 4:
+      switch (tensor_lists[0][0].type().scalarType()) {
+        case at::ScalarType::Half:
+          multi_tensor_apply<4>(
+            BLOCK_SIZE,
+            chunk_size,
+            noop_flag,
+            tensor_lists,
+            SGDFunctor<4, at::Half, float>(),
+            wd,
+            momentum,
+            dampening,
+            lr,
+            nesterov,
+            first_run);
+          break;
+        case at::ScalarType::Float:
+          multi_tensor_apply<4>(
+            BLOCK_SIZE,
+            chunk_size,
+            noop_flag,
+            tensor_lists,
+            SGDFunctor<4, float, float>(),
+            wd,
+            momentum,
+            dampening,
+            lr,
+            nesterov,
+            first_run);
+          break;
+        default:
+          AT_ERROR("multi_tensor_sgd only takes Half and Float gradients, given: ",
+                   tensor_lists[0][0].type().scalarType());
+      }
+      break;  // without this break, case 4 falls through into the default error
+    default:
+      AT_ERROR("multi_tensor_sgd takes either 3 or 4 sets of tensors, given ", num_tensors);
+  }
+
   AT_CUDA_CHECK(cudaGetLastError());
+  // AT_CUDA_CHECK(cudaDeviceSynchronize());
 }
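
The host-side dispatch above keys on the pair (number of tensor lists, gradient dtype) to pick a functor instantiation; in every case the weights themselves stay fp32 (T = float), and only the gradient read and the optional weight copy touch fp16. A table-driven sketch of the same selection in Python (the function name and table are illustrative):

import torch

# (num_lists, grad_dtype) -> CUDA functor instantiation, per the switch above
DISPATCH = {
    (3, torch.float16): "SGDFunctor<3, at::Half, float>",
    (3, torch.float32): "SGDFunctor<3, float, float>",
    (4, torch.float16): "SGDFunctor<4, at::Half, float>",
    (4, torch.float32): "SGDFunctor<4, float, float>",
}

def pick_functor(tensor_lists):
    key = (len(tensor_lists), tensor_lists[0][0].dtype)
    if key not in DISPATCH:
        raise RuntimeError("multi_tensor_sgd takes 3 or 4 tensor lists of "
                           "Half or Float gradients, given: {}".format(key))
    return DISPATCH[key]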