Commit a2799893 authored by Simon Layton
Browse files

Handle fp16 weights case without forcing fp16 math

Fix incorrect types that were used in a few places.
parent 75c8a97a
......@@ -24,7 +24,7 @@
* nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init
**/
template<int N, typename T_grad, typename T>
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
__device__ __forceinline__ void operator()(
......@@ -53,24 +53,24 @@ struct SGDFunctor
T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
grad_in += chunk_idx*chunk_size;
T* weight_in = (T*)tl.addresses[1][tensor_loc];
T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
weight_in += chunk_idx*chunk_size;
T* mom_in = (T*)tl.addresses[2][tensor_loc];
T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
mom_in += chunk_idx*chunk_size;
half *model_weights_out = nullptr;
at::Half *model_weights_out = nullptr;
if (N == 4) {
model_weights_out = (half*)tl.addresses[3][tensor_loc];
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
// Non-divergent exit condition for the __syncthreads
T incoming_grads[ILP];
T incoming_weights[ILP];
T incoming_moms[ILP];
float incoming_grads[ILP];
float incoming_weights[ILP];
float incoming_moms[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
......@@ -83,9 +83,9 @@ struct SGDFunctor
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
incoming_grads[ii] = static_cast<T>(grad_in[i]);
incoming_weights[ii] = static_cast<T>(weight_in[i]);
incoming_moms[ii] = static_cast<T>(mom_in[i]);
incoming_grads[ii] = static_cast<float>(grad_in[i]);
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
// note for clarification to future michael:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment