Commit 79b2cc28 authored by Thor Johnsen
Browse files

Bug fix

parent bd6b1ebc
...@@ -79,7 +79,7 @@ __global__ void adam_cuda_kernel( ...@@ -79,7 +79,7 @@ __global__ void adam_cuda_kernel(
T pi[ILP]; T pi[ILP];
T gi[ILP]; T gi[ILP];
bool overflow = False; bool overflow = false;
for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) { for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) {
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) { for(int ii = 0; ii < ILP; ii++) {
...@@ -99,7 +99,6 @@ __global__ void adam_cuda_kernel( ...@@ -99,7 +99,6 @@ __global__ void adam_cuda_kernel(
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) { for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i*ILP;
T scaled_grad = gi[ii]/grad_scale; T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) { if (isfinite(scaled_grad)) {
mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad; mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad;
...@@ -112,7 +111,7 @@ __global__ void adam_cuda_kernel( ...@@ -112,7 +111,7 @@ __global__ void adam_cuda_kernel(
float update = (mi[ii]/denom) + (decay*pi[ii]); float update = (mi[ii]/denom) + (decay*pi[ii]);
pi[ii] = pi[ii] - (step_size*update); pi[ii] = pi[ii] - (step_size*update);
} else { } else {
overflow = True; overflow = true;
} }
} }
...@@ -137,7 +136,7 @@ __global__ void adam_cuda_kernel( ...@@ -137,7 +136,7 @@ __global__ void adam_cuda_kernel(
} }
template <typename T, typename GRAD_T> template <typename T, typename GRAD_T>
__global__ __device__ void adam_undo_cuda_kernel( __global__ void adam_undo_cuda_kernel(
T* __restrict__ p, T* __restrict__ p,
T* __restrict__ m, T* __restrict__ m,
T* __restrict__ v, T* __restrict__ v,
...@@ -182,7 +181,6 @@ __global__ __device__ void adam_undo_cuda_kernel( ...@@ -182,7 +181,6 @@ __global__ __device__ void adam_undo_cuda_kernel(
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) { for(int ii = 0; ii < ILP; ii++) {
int j = j_start + i*ILP;
T scaled_grad = gi[ii]/grad_scale; T scaled_grad = gi[ii]/grad_scale;
if (isfinite(scaled_grad)) { if (isfinite(scaled_grad)) {
float denom; float denom;
...@@ -195,7 +193,7 @@ __global__ __device__ void adam_undo_cuda_kernel( ...@@ -195,7 +193,7 @@ __global__ __device__ void adam_undo_cuda_kernel(
vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2; vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;
// Make sure round off errors don't create (small) negative value. // Make sure round off errors don't create (small) negative value.
// This can happen if we have to revert the very first step. // This can happen if we have to revert the very first step.
vii[ii] = vii[i] >= 0.0f ? vi[ii] : 0.0f; vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
} }
} }
...@@ -252,7 +250,7 @@ struct AdamFunctor ...@@ -252,7 +250,7 @@ struct AdamFunctor
T pi[ILP]; T pi[ILP];
T gi[ILP]; T gi[ILP];
bool overflow = False; bool overflow = false;
for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) { for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) {
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) { for(int ii = 0; ii < ILP; ii++) {
...@@ -262,7 +260,7 @@ struct AdamFunctor ...@@ -262,7 +260,7 @@ struct AdamFunctor
gi[ii] = GRAD_T(0); gi[ii] = GRAD_T(0);
int j = j_start + threadIdx.x + ii*blockDim.x; int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < tsize) { if (j < dim) {
pi[ii] = p[j]; pi[ii] = p[j];
mi[ii] = m[j]; mi[ii] = m[j];
vi[ii] = v[j]; vi[ii] = v[j];
...@@ -285,14 +283,14 @@ struct AdamFunctor ...@@ -285,14 +283,14 @@ struct AdamFunctor
float update = (mi[ii]/denom) + (decay*pi[ii]); float update = (mi[ii]/denom) + (decay*pi[ii]);
pi[ii] = pi[ii] - (step_size*update); pi[ii] = pi[ii] - (step_size*update);
} else { } else {
overflow = True; overflow = true;
} }
} }
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) { for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x; int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < tsize) { if (j < dim) {
m[j] = mi[ii]; m[j] = mi[ii];
v[j] = vi[ii]; v[j] = vi[ii];
p[j] = pi[ii]; p[j] = pi[ii];
...@@ -352,7 +350,7 @@ struct AdamUndoFunctor ...@@ -352,7 +350,7 @@ struct AdamUndoFunctor
gi[ii] = GRAD_T(0); gi[ii] = GRAD_T(0);
int j = j_start + threadIdx.x + ii*blockDim.x; int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < tsize) { if (j < dim) {
pi[ii] = p[j]; pi[ii] = p[j];
mi[ii] = m[j]; mi[ii] = m[j];
vi[ii] = v[j]; vi[ii] = v[j];
...@@ -375,14 +373,14 @@ struct AdamUndoFunctor ...@@ -375,14 +373,14 @@ struct AdamUndoFunctor
vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2; vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2;
// Make sure round off errors don't create (small) negative value. // Make sure round off errors don't create (small) negative value.
// This can happen if we have to revert the very first step. // This can happen if we have to revert the very first step.
vii[ii] = vii[i] >= 0.0f ? vi[ii] : 0.0f; vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f;
} }
} }
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) { for(int ii = 0; ii < ILP; ii++) {
int j = j_start + threadIdx.x + ii*blockDim.x; int j = j_start + threadIdx.x + ii*blockDim.x;
if (j < tsize) { if (j < dim) {
m[j] = mi[ii]; m[j] = mi[ii];
v[j] = vi[ii]; v[j] = vi[ii];
p[j] = pi[ii]; p[j] = pi[ii];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment