"src/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "ceeab4692263a3e058b3afec4269ebef8e40a038"
Commit 8618bed0 authored by Phil Wang's avatar Phil Wang
Browse files

swap the order in which momentum and parameters are updated in ops.cu

parent c5582724
...@@ -120,8 +120,6 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, ...@@ -120,8 +120,6 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
case MOMENTUM: case MOMENTUM:
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
case LION:
if(max_unorm > 0.0f) if(max_unorm > 0.0f)
{ {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
...@@ -132,6 +130,18 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, ...@@ -132,6 +130,18 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case LION:
// in lion, the momentum update after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
break;
} }
} }
...@@ -164,7 +174,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, ...@@ -164,7 +174,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case MOMENTUM: case MOMENTUM:
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
case LION:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
...@@ -172,6 +181,16 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, ...@@ -172,6 +181,16 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case LION:
// in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
default: default:
break; break;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment