Commit 9b656f46 authored by Phil Wang's avatar Phil Wang
Browse files

follow advice of Tim to fix update of momentum vs parameters in blockwise 8 bit

parent 369a51c4
...@@ -1708,6 +1708,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1708,6 +1708,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
s1_vals[j] = (s1_vals[j]*beta1) + g_val; s1_vals[j] = (s1_vals[j]*beta1) + g_val;
break; break;
case LION: case LION:
g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break; break;
case RMSPROP: case RMSPROP:
...@@ -1748,7 +1749,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char ...@@ -1748,7 +1749,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
break; break;
case LION: case LION:
p_vals[j] = ((float)p_vals[j]) - lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))); p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
break; break;
case RMSPROP: case RMSPROP:
g_val = g_vals[j]; g_val = g_vals[j];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment