Commit c99b44f7 authored by Phil Wang
Browse files

Do the epsilon/beta2 switcheroo within the CUDA code, and not within the...

Do the epsilon/beta2 switcheroo within the CUDA code, and not within the Python class (so that the state dict still makes sense).
parent 8618bed0
...@@ -18,13 +18,12 @@ class Lion(Optimizer1State): ...@@ -18,13 +18,12 @@ class Lion(Optimizer1State):
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
beta1, beta2 = betas
super().__init__( super().__init__(
"lion", "lion",
params, params,
lr, lr,
(beta1, 0.), (beta1, beta2),
beta2, 0.,
weight_decay, weight_decay,
optim_bits, optim_bits,
args, args,
...@@ -46,13 +45,12 @@ class Lion8bit(Optimizer1State): ...@@ -46,13 +45,12 @@ class Lion8bit(Optimizer1State):
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
beta1, beta2 = betas
super().__init__( super().__init__(
"lion", "lion",
params, params,
lr, lr,
(beta1, 0.), (beta1, beta2),
beta2, 0.,
weight_decay, weight_decay,
8, 8,
args, args,
...@@ -74,13 +72,12 @@ class Lion32bit(Optimizer1State): ...@@ -74,13 +72,12 @@ class Lion32bit(Optimizer1State):
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
beta1, beta2 = betas
super().__init__( super().__init__(
"lion", "lion",
params, params,
lr, lr,
(beta1, 0.), betas,
beta2, 0.,
weight_decay, weight_decay,
32, 32,
args, args,
......
...@@ -132,13 +132,13 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, ...@@ -132,13 +132,13 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
break; break;
case LION: case LION:
// in lion, the momentum update after the parameter update // in lion, the momentum update after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
if(max_unorm > 0.0f) if(max_unorm > 0.0f)
{ {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
break; break;
...@@ -183,12 +183,12 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, ...@@ -183,12 +183,12 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
break; break;
case LION: case LION:
// in lion, the momentum update happens after the parameter update // in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
default: default:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment