Commit a2799893 authored by Simon Layton
Browse files

Handle fp16 weights case without forcing fp16 math

Fix incorrect types that were used in a few places
parent 75c8a97a
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
* nesterov : enable nesterov (bool) * nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init * first run : necessary for proper momentum handling & init
**/ **/
template<int N, typename T_grad, typename T> template<int N, typename T_grad, typename T_weight>
struct SGDFunctor struct SGDFunctor
{ {
__device__ __forceinline__ void operator()( __device__ __forceinline__ void operator()(
...@@ -53,24 +53,24 @@ struct SGDFunctor ...@@ -53,24 +53,24 @@ struct SGDFunctor
T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc]; T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
grad_in += chunk_idx*chunk_size; grad_in += chunk_idx*chunk_size;
T* weight_in = (T*)tl.addresses[1][tensor_loc]; T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
weight_in += chunk_idx*chunk_size; weight_in += chunk_idx*chunk_size;
T* mom_in = (T*)tl.addresses[2][tensor_loc]; T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
mom_in += chunk_idx*chunk_size; mom_in += chunk_idx*chunk_size;
half *model_weights_out = nullptr; at::Half *model_weights_out = nullptr;
if (N == 4) { if (N == 4) {
model_weights_out = (half*)tl.addresses[3][tensor_loc]; model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx*chunk_size; model_weights_out += chunk_idx*chunk_size;
} }
n -= chunk_idx*chunk_size; n -= chunk_idx*chunk_size;
// Non-divergent exit condition for the __syncthreads // Non-divergent exit condition for the __syncthreads
T incoming_grads[ILP]; float incoming_grads[ILP];
T incoming_weights[ILP]; float incoming_weights[ILP];
T incoming_moms[ILP]; float incoming_moms[ILP];
for(int i_start = 0; for(int i_start = 0;
i_start < n && i_start < chunk_size; i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) i_start += blockDim.x*ILP)
...@@ -83,9 +83,9 @@ struct SGDFunctor ...@@ -83,9 +83,9 @@ struct SGDFunctor
incoming_moms[ii] = 0; incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) if(i < n && i < chunk_size)
incoming_grads[ii] = static_cast<T>(grad_in[i]); incoming_grads[ii] = static_cast<float>(grad_in[i]);
incoming_weights[ii] = static_cast<T>(weight_in[i]); incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<T>(mom_in[i]); incoming_moms[ii] = static_cast<float>(mom_in[i]);
} }
// note for clarification to future michael: // note for clarification to future michael:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment