Commit 3b32c401 authored by Michael Carilli's avatar Michael Carilli
Browse files

Fixed bounds checking

parent 2c63ba91
...@@ -63,7 +63,7 @@ class FusedSGD(Optimizer): ...@@ -63,7 +63,7 @@ class FusedSGD(Optimizer):
weight_decay=weight_decay, nesterov=nesterov) weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0): if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening") raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SGD, self).__init__(params, defaults) super(FusedSGD, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum self.wd_after_momentum = wd_after_momentum
...@@ -80,8 +80,9 @@ class FusedSGD(Optimizer): ...@@ -80,8 +80,9 @@ class FusedSGD(Optimizer):
for group in self.param_groups: for group in self.param_groups:
group.setdefault('nesterov', False) group.setdefault('nesterov', False)
def get_momentums(params): def get_momentums(self, params):
momentums = [] momentums = []
first_run = True
for p in params: for p in params:
param_state = self.state[p] param_state = self.state[p]
# torch.optim.SGD initializes momentum in the main loop, we have # torch.optim.SGD initializes momentum in the main loop, we have
...@@ -153,7 +154,7 @@ class FusedSGD(Optimizer): ...@@ -153,7 +154,7 @@ class FusedSGD(Optimizer):
launch_sets = [[fp16_grads, fp16_params, fp16_momentums], launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
[fp32_grads, fp32_params, fp32_momentums]] [fp32_grads, fp32_params, fp32_momentums]]
for launch_set, first_run in zip(launch_sets, first_runs): for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):
assert len(launch_set[0]) == len(launch_set[1]) assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2]) assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0: if len(launch_set[0]) > 0:
......
...@@ -57,7 +57,8 @@ struct SGDFunctor ...@@ -57,7 +57,8 @@ struct SGDFunctor
mom_in += chunk_idx*chunk_size; mom_in += chunk_idx*chunk_size;
at::Half *model_weights_out = nullptr; at::Half *model_weights_out = nullptr;
if (N == 4) { if(N == 4)
{
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc]; model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx*chunk_size; model_weights_out += chunk_idx*chunk_size;
} }
...@@ -80,10 +81,12 @@ struct SGDFunctor ...@@ -80,10 +81,12 @@ struct SGDFunctor
incoming_moms[ii] = 0; incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) if(i < n && i < chunk_size)
{
incoming_grads[ii] = static_cast<float>(grad_in[i]); incoming_grads[ii] = static_cast<float>(grad_in[i]);
incoming_weights[ii] = static_cast<float>(weight_in[i]); incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]); incoming_moms[ii] = static_cast<float>(mom_in[i]);
} }
}
// note for clarification to future michael: // note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling // From a pure memory dependency perspective, there's likely no point unrolling
...@@ -94,47 +97,44 @@ struct SGDFunctor ...@@ -94,47 +97,44 @@ struct SGDFunctor
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) { if(i < n && i < chunk_size)
{
// apply weight decay before momentum if necessary // apply weight decay before momentum if necessary
if (wd != 0.f && !wd_after_momentum) { if(wd != 0.f && !wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii]; incoming_grads[ii] += wd * incoming_weights[ii];
}
if (momentum != 0.f) { if(momentum != 0.f)
if (!first_run) { {
if(!first_run)
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii]; incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
} else { else
// initialize momentum to current incoming grads // initialize momentum to current incoming grads
incoming_moms[ii] = incoming_grads[ii]; incoming_moms[ii] = incoming_grads[ii];
}
if (nesterov) { if(nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii]; incoming_grads[ii] += momentum * incoming_moms[ii];
} else { else
incoming_grads[ii] = incoming_moms[ii]; incoming_grads[ii] = incoming_moms[ii];
} }
}
// Apply WD after momentum if desired // Apply WD after momentum if desired
if (wd != 0.f && wd_after_momentum) { if(wd != 0.f && wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii]; incoming_grads[ii] += wd * incoming_weights[ii];
}
// adjust the weight and write out // adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]); weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights // if necessary, write out an fp16 copy of the weights
if (N == 4) { if(N == 4)
model_weights_out[i] = static_cast<at::Half>(weight_in[i]); model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
}
// also write out the new momentum // also write out the new momentum
if (momentum != 0.f) { if(momentum != 0.f)
mom_in[i] = incoming_moms[ii]; mom_in[i] = incoming_moms[ii];
} }
} }
} }
} }
}
}; };
void multi_tensor_sgd_cuda( void multi_tensor_sgd_cuda(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment