Commit 3b32c401 authored by Michael Carilli's avatar Michael Carilli
Browse files

Fixed bounds checking

parent 2c63ba91
......@@ -63,7 +63,7 @@ class FusedSGD(Optimizer):
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SGD, self).__init__(params, defaults)
super(FusedSGD, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum
......@@ -80,8 +80,9 @@ class FusedSGD(Optimizer):
for group in self.param_groups:
group.setdefault('nesterov', False)
def get_momentums(params):
def get_momentums(self, params):
momentums = []
first_run = True
for p in params:
param_state = self.state[p]
# torch.optim.SGD initializes momentum in the main loop, we have
......@@ -153,7 +154,7 @@ class FusedSGD(Optimizer):
launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
[fp32_grads, fp32_params, fp32_momentums]]
for launch_set, first_run in zip(launch_sets, first_runs):
for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):
assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0:
......
......@@ -57,7 +57,8 @@ struct SGDFunctor
mom_in += chunk_idx*chunk_size;
at::Half *model_weights_out = nullptr;
if (N == 4) {
if(N == 4)
{
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx*chunk_size;
}
......@@ -80,9 +81,11 @@ struct SGDFunctor
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
incoming_grads[ii] = static_cast<float>(grad_in[i]);
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
}
// note for clarification to future michael:
......@@ -94,43 +97,40 @@ struct SGDFunctor
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) {
if(i < n && i < chunk_size)
{
// apply weight decay before momentum if necessary
if (wd != 0.f && !wd_after_momentum) {
if(wd != 0.f && !wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
}
if (momentum != 0.f) {
if (!first_run) {
if(momentum != 0.f)
{
if(!first_run)
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
} else {
else
// initialize momentume to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
}
if (nesterov) {
if(nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii];
} else {
else
incoming_grads[ii] = incoming_moms[ii];
}
}
// Apply WD after momentum if desired
if (wd != 0.f && wd_after_momentum) {
if(wd != 0.f && wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
}
// adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights
if (N == 4) {
if(N == 4)
model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
}
// also write out the new momentum
if (momentum != 0.f) {
if(momentum != 0.f)
mom_in[i] = incoming_moms[ii];
}
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment