OpenDAS / bitsandbytes — commit b0ec20c3 (unverified)
Authored Apr 11, 2023 by Tim Dettmers; committed via GitHub on Apr 11, 2023.

Merge pull request #188 from lucidrains/main

Lion 8 bit

Parents: d3e0e39d 2a6828e6
Showing 11 changed files with 249 additions and 23 deletions (+249 −23):
README.md                        +1   −1
bitsandbytes/functional.py       +12  −0
bitsandbytes/optim/__init__.py   +1   −0
bitsandbytes/optim/lion.py       +87  −0
csrc/kernels.cu                  +75  −12
csrc/kernels.cuh                 +4   −4
csrc/ops.cu                      +33  −5
csrc/ops.cuh                     +1   −0
csrc/pythonInterface.c           +12  −0
requirements.txt                 +1   −0
tests/test_optim.py              +22  −1
README.md

@@ -40,7 +40,7 @@ out = linear(x.to(torch.float16))
 ## Features
 - 8-bit Matrix multiplication with mixed precision decomposition
 - LLM.int8() inference
-- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory)
+- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
 - Stable Embedding Layer: Improved stability through better initialization, and normalization
 - 8-bit quantization: Quantile, Linear, and Dynamic quantization
 - Fast quantile estimation: Up to 100x faster than other algorithms
bitsandbytes/functional.py

@@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop32bit_g32,
         lib.crmsprop32bit_g16,
     )
+    str2optimizer32bit["lion"] = (
+        lib.clion32bit_g32,
+        lib.clion32bit_g16,
+    )
     str2optimizer32bit["adagrad"] = (
         lib.cadagrad32bit_g32,
         lib.cadagrad32bit_g16,

@@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop_static_8bit_g32,
         lib.crmsprop_static_8bit_g16,
     )
+    str2optimizer8bit["lion"] = (
+        lib.clion_static_8bit_g32,
+        lib.clion_static_8bit_g16,
+    )
     str2optimizer8bit["lamb"] = (
         lib.cadam_static_8bit_g32,
         lib.cadam_static_8bit_g16,

@@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop_8bit_blockwise_fp32,
         lib.crmsprop_8bit_blockwise_fp16,
     )
+    str2optimizer8bit_blockwise["lion"] = (
+        lib.clion_8bit_blockwise_fp32,
+        lib.clion_8bit_blockwise_fp16,
+    )
     str2optimizer8bit_blockwise["adagrad"] = (
         lib.cadagrad_8bit_blockwise_fp32,
         lib.cadagrad_8bit_blockwise_fp16,
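These tables map each optimizer name to a pair of compiled CUDA entry points, one per gradient dtype (fp32, fp16). As a rough sketch of how such a (g32, g16) tuple is consumed at step time (the dispatch code itself is outside these hunks, so the wrapper below is illustrative, not the library's exact function):

# Hypothetical dispatch sketch: pick the CUDA entry point by optimizer
# name and gradient dtype, mirroring the (g32, g16) tuple layout above.
import torch

def pick_32bit_kernel(optimizer_name, grad, str2optimizer32bit):
    fn_fp32, fn_fp16 = str2optimizer32bit[optimizer_name]
    return fn_fp32 if grad.dtype == torch.float32 else fn_fp16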
bitsandbytes/optim/__init__.py

@@ -12,4 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .optimizer import GlobalOptimManager
 from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .lion import Lion, Lion8bit, Lion32bit
 from .sgd import SGD, SGD8bit, SGD32bit
bitsandbytes/optim/lion.py (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class Lion(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "lion",
            params,
            lr,
            betas,
            0.,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Lion8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "lion",
            params,
            lr,
            betas,
            0.,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Lion32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "lion",
            params,
            lr,
            betas,
            0.,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
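All three classes delegate to Optimizer1State under the "lion" kernel name; Lion8bit and Lion32bit simply pin optim_bits to 8 and 32. A minimal usage sketch (the toy model and data here are made up; only the optimizer constructor comes from this commit):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()   # toy module
opt = bnb.optim.Lion8bit(model.parameters(), lr=1e-4, betas=(0.9, 0.99))

x = torch.randn(16, 512, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
opt.step()        # Lion update with 8-bit optimizer state
opt.zero_grad()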
csrc/kernels.cu

@@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
   return __int_as_float(old);
 }

+// sign function for lion
+// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
+template <typename T> __device__ int sgn(T val)
+{
+    return (T(0) < val) - (val < T(0));
+}
+
 template <int STOCHASTIC>
 __device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
 {
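The expression (T(0) < val) - (val < T(0)) is the standard branchless three-way sign: 1 for positive, -1 for negative, 0 for zero. A quick Python check of the same identity (illustrative only):

def sgn(val):
    # branchless sign: (0 < val) - (val < 0) yields 1, -1, or 0
    return (0 < val) - (val < 0)

assert [sgn(v) for v in (3.5, -2.0, 0.0)] == [1, -1, 0]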
@@ -743,7 +751,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
 __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
                 float* state1, float *unorm,
-                const float beta1, const float eps, const float weight_decay,
+                const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const int n)
 {
@@ -790,6 +798,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
           s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
           s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
           break;
+        case LION:
+          s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
+          break;
         case RMSPROP:
           s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
           s1_vals[j] = __fdividef((float)g_vals[j], sqrtf(s1_vals[j])+eps); // update value
@@ -821,7 +832,7 @@ template<typename T, int OPTIMIZER>
 __launch_bounds__(TH, 1)
 __global__ void kOptimizer32bit1State(T *g, T *p,
                 float *state1, float *unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float eps, const float weight_decay,
+                const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
 {
@@ -890,6 +901,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
             p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
             break;
+          case LION:
+            p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
+            s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
+            break;
           case RMSPROP:
             s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
             p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j], sqrtf((float)s1_vals[j])+eps));
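This is the core Lion step: the parameter moves by the sign of an interpolation between the old momentum and the current gradient, and only afterwards is the momentum decayed toward the gradient by beta2. A plain-Python rendering of the same two lines (a reference sketch, not the kernel):

def lion_step(p, m, g, lr, beta1, beta2):
    sign = lambda v: (0 < v) - (v < 0)
    # the parameter update uses the OLD momentum, interpolated with the gradient
    p = p - lr * sign(beta1 * m + (1.0 - beta1) * g)
    # the momentum is updated only after the parameter update
    m = beta2 * m + (1.0 - beta2) * g
    return p, m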
@@ -1158,7 +1173,7 @@ __global__ void
 __launch_bounds__(NUM_THREADS, 2)
 kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
                 float *unorm,
-                const float beta1,
+                const float beta1, const float beta2,
                 const float eps, const int step,
                 float* __restrict__ const quantiles1,
                 float* max1, float* new_max1,
@@ -1219,6 +1234,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
                 if(unorm != NULL)
                   local_unorm += s1_vals[j]*s1_vals[j];
                 break;
+              case LION:
+                s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+                break;
              case RMSPROP:
                  s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                 break;
@@ -1244,7 +1262,7 @@ template<typename T, int OPTIMIZER>
 __global__ void
 kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                 const float *unorm, const float max_unorm, const float param_norm,
-                const float beta1,
+                const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1,
                 float* max1, float* new_max1,
@@ -1307,8 +1325,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
             {
                 g_val = float(g_vals[j]);
                 g_val *= gnorm_scale;
-                if(weight_decay > 0.0f)
-                  g_val += ((float)p_vals[j])*weight_decay;
+                if(weight_decay > 0.0f)
+                {
+                  switch(OPTIMIZER) {
+                    case MOMENTUM:
+                    case RMSPROP:
+                      g_val += ((float)p_vals[j])*weight_decay;
+                      break;
+                    case LION:
+                      p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
+                      break;
+                  }
+                }

                 s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];

                 switch(OPTIMIZER)
@@ -1321,6 +1350,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                     p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
                     break;
+                case LION:
+                    p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
+                    s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+                    break;
                 case RMSPROP:
                     s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                     p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val, sqrtf(s1_vals[j])+eps));
@@ -1649,10 +1682,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
         {
             g_val = float(g_vals[j]);
             g_val *= gnorm_scale;
-            if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
-            {
-              if(weight_decay > 0.0f)
-                g_val += ((float)p_vals[j])*weight_decay;
+            if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
+            {
+              if(weight_decay > 0.0f)
+              {
+                switch(OPTIMIZER) {
+                  case MOMENTUM:
+                  case ADAGRAD:
+                  case RMSPROP:
+                    g_val += ((float)p_vals[j])*weight_decay;
+                    break;
+                  case LION:
+                    p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
+                    break;
+                }
+              }

               s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
@@ -1664,6 +1707,11 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
                   else
                     s1_vals[j] = (s1_vals[j]*beta1) + g_val;
                   break;
+                case LION:
+                  // here, g_vals[j] stores the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2
+                  g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
+                  s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+                  break;
                 case RMSPROP:
                   s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                   break;
@@ -1701,6 +1749,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
               case MOMENTUM:
                 p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
                 break;
+              case LION:
+                p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
+                break;
              case RMSPROP:
                g_val = g_vals[j];
                p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
@@ -2692,24 +2743,28 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c
 #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
 template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
                 float* state1, float *unorm, \
-                const float beta1, const float eps, const float weight_decay, \
+                const float beta1, const float beta2, const float eps, const float weight_decay, \
                 const int step, const float lr, const float gnorm_scale, const int n); \

 MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
 MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
+MAKE_PreconditionOptimizer32bit1State(LION, half)
+MAKE_PreconditionOptimizer32bit1State(LION, float)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)

 #define MAKE_Optimizer32bit1State(oname, gtype) \
 template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
-                const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
+                const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \

 MAKE_Optimizer32bit1State(MOMENTUM, half)
 MAKE_Optimizer32bit1State(MOMENTUM, float)
 MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
+MAKE_Optimizer32bit1State(LION, half)
+MAKE_Optimizer32bit1State(LION, float)
 MAKE_Optimizer32bit1State(ADAGRAD, half)
 MAKE_Optimizer32bit1State(ADAGRAD, float)

@@ -2731,6 +2786,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p,
 template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
                 float *unorm, \
                 const float beta1, \
+                const float beta2, \
                 const float eps, const int step, \
                 float* __restrict__ const quantiles1, \
                 float* max1, float* new_max1, \

@@ -2742,11 +2798,14 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
 MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
 MAKE_PreconditionStatic8bit1State(RMSPROP, half)
 MAKE_PreconditionStatic8bit1State(RMSPROP, float)
+MAKE_PreconditionStatic8bit1State(LION, half)
+MAKE_PreconditionStatic8bit1State(LION, float)

 #define MAKE_optimizerStatic8bit1State(oname, gtype) \
 template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
                 const float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1, \
+                const float beta2, \
                 const float eps, const int step, const float lr, \
                 float* __restrict__ const quantiles1, \
                 float* max1, float* new_max1, \

@@ -2758,6 +2817,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
 MAKE_optimizerStatic8bit1State(MOMENTUM, float)
 MAKE_optimizerStatic8bit1State(RMSPROP, half)
 MAKE_optimizerStatic8bit1State(RMSPROP, float)
+MAKE_optimizerStatic8bit1State(LION, half)
+MAKE_optimizerStatic8bit1State(LION, float)

 #define MAKE_PreconditionStatic8bit2State(oname, gtype) \
 template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \

@@ -2849,5 +2910,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
csrc/kernels.cuh

@@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
                 float* state1, float* unorm,
-                const float beta1, const float eps, const float weight_decay,
+                const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const int n);

 template<typename T, int OPTIMIZER>
 __global__ void kOptimizer32bit1State(T* g, T* p,
                 float* state1, float* unorm, const float max_unorm, const float param_norm,
-                const float beta1, const float eps, const float weight_decay,
+                const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

 template<typename T, int OPTIMIZER>
 __global__ void kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
                 float *unorm,
-                const float beta1,
+                const float beta1, const float beta2,
                 const float eps, const int step,
                 float* __restrict__ const quantiles1,
                 float* max1, float* new_max1,

@@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER>
 __global__ void kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                 const float *unorm, const float max_unorm, const float param_norm,
-                const float beta1,
+                const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1,
                 float* max1, float* new_max1,
csrc/ops.cu

@@ -120,17 +120,28 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
     case MOMENTUM:
     case RMSPROP:
     case ADAGRAD:
       if(max_unorm > 0.0f)
       {
         CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
-        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }

-      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
+    case LION:
+      // in lion, the momentum update happens after the parameter update
+      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+      if(max_unorm > 0.0f)
+      {
+        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
+        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+        CUDA_CHECK_RETURN(cudaPeekAtLastError());
+      }
+      break;
   }
 }
@@ -164,12 +175,22 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
     case RMSPROP:
     case ADAGRAD:
       CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
-      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
-      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
+      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
           quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
+    case LION:
+      // in lion, the momentum update happens after the parameter update
+      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+          quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+      CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
+      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      CUDA_CHECK_RETURN(cudaPeekAtLastError());
+      break;
     default:
       break;
   }
@@ -198,6 +219,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
     case MOMENTUM:
     case RMSPROP:
     case ADAGRAD:
+    case LION:
       num_blocks = n/BLOCKSIZE_1STATE;
       num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
       kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,

@@ -707,6 +729,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
 MAKE_optimizer32bit(MOMENTUM, float)
 MAKE_optimizer32bit(RMSPROP, half)
 MAKE_optimizer32bit(RMSPROP, float)
+MAKE_optimizer32bit(LION, half)
+MAKE_optimizer32bit(LION, float)
 MAKE_optimizer32bit(ADAGRAD, half)
 MAKE_optimizer32bit(ADAGRAD, float)

@@ -726,6 +750,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half)
 MAKE_optimizerStatic8bit(MOMENTUM, float)
 MAKE_optimizerStatic8bit(RMSPROP, half)
 MAKE_optimizerStatic8bit(RMSPROP, float)
+MAKE_optimizerStatic8bit(LION, half)
+MAKE_optimizerStatic8bit(LION, float)

 #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
 template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \

@@ -738,6 +764,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
+MAKE_optimizerStatic8bitBlockwise(half, LION);
+MAKE_optimizerStatic8bitBlockwise(float, LION);
 MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
csrc/ops.cuh

@@ -70,6 +70,7 @@ typedef enum Optimizer_t
   RMSPROP = 2,
   LARS = 3,
   ADAGRAD = 4,
+  LION = 5,
 } Optimizer_t;

 typedef enum Transform_t
csrc/pythonInterface.c

@@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
 MAKE_FUNC32(adam, ADAM, half, 16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
+MAKE_FUNC32(lion, LION, float, 32)
+MAKE_FUNC32(lion, LION, half, 16)
 MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
 MAKE_FUNC32(adagrad, ADAGRAD, half, 16)

@@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32)
 MAKE_FUNC8(momentum, MOMENTUM, half, 16)
 MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
+MAKE_FUNC8(lion, LION, float, 32)
+MAKE_FUNC8(lion, LION, half, 16)

 #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
 void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \

@@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
 MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
+MAKE_BLOCKWISE8(lion, LION, half, 16)
+MAKE_BLOCKWISE8(lion, LION, float, 32)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
 MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)

@@ -161,6 +167,8 @@ extern "C"
   MAKE_CFUNC32(momentum, half, 16)
   MAKE_CFUNC32(rmsprop, float, 32)
   MAKE_CFUNC32(rmsprop, half, 16)
+  MAKE_CFUNC32(lion, float, 32)
+  MAKE_CFUNC32(lion, half, 16)
   MAKE_CFUNC32(adagrad, float, 32)
   MAKE_CFUNC32(adagrad, half, 16)

@@ -183,6 +191,8 @@ extern "C"
   MAKE_CFUNC8(momentum, half, 16)
   MAKE_CFUNC8(rmsprop, float, 32)
   MAKE_CFUNC8(rmsprop, half, 16)
+  MAKE_CFUNC8(lion, float, 32)
+  MAKE_CFUNC8(lion, half, 16)

 #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
 void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \

@@ -196,6 +206,8 @@ extern "C"
   MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
   MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
   MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
+  MAKE_CBLOCKWISE8(lion, LION, half, 16)
+  MAKE_CBLOCKWISE8(lion, LION, float, 32)
   MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
   MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
requirements.txt

+lion-pytorch
 pytest
tests/test_optim.py

@@ -7,6 +7,8 @@ from itertools import product
 from os.path import join

 import pytest
+from lion_pytorch import Lion

 import torch

 import bitsandbytes as bnb

@@ -31,6 +33,7 @@ str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
 # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
 # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
 str2optimizers["momentum_pytorch"] = (
     None,
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
@@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = (
 )
 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
 # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+str2optimizers["lion"] = (Lion, bnb.optim.Lion)
 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
@@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = (
     torch.optim.Adam,
     lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
 )
+str2optimizers["lion8bit"] = (
+    Lion,
+    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
+)
 str2optimizers["momentum8bit"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = (
     torch.optim.Adam,
     lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
 )
+str2optimizers["lion8bit_blockwise"] = (
+    Lion,
+    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
+)
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
 str2statenames = {}
 str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["lion"] = [("exp_avg", "state1")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
 str2statenames["lars"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
@@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [
     ("exp_avg", "state1", "qmap1", "max1"),
     ("exp_avg_sq", "state2", "qmap2", "max2"),
 ]
+str2statenames["lion8bit"] = [
+    ("exp_avg", "state1", "qmap1", "max1")
+]
 str2statenames["lamb8bit"] = [
     ("exp_avg", "state1", "qmap1", "max1"),
     ("exp_avg_sq", "state2", "qmap2", "max2"),
@@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [
     ("exp_avg", "state1", "qmap1", "absmax1"),
     ("exp_avg_sq", "state2", "qmap2", "absmax2"),
 ]
+str2statenames["lion8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1")
+]
 str2statenames["momentum8bit"] = [
     ("momentum_buffer", "state1", "qmap1", "max1")
 ]
@@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
 gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
+optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values

@@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097]
 gtype = [torch.float32, torch.float16]
 optimizer_names = [
     "adam8bit",
+    "lion8bit",
     "momentum8bit",
     "rmsprop8bit",
     "adam8bit_blockwise",
+    "lion8bit_blockwise",
     "lars8bit",
     "momentum8bit_blockwise",
     "rmsprop8bit_blockwise",
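The new cases run bnb.optim.Lion against the reference lion-pytorch implementation, comparing parameters and optimizer state after each step ("exp_avg" on the reference side corresponds to "state1" in bitsandbytes, per str2statenames above). A condensed sketch of the comparison pattern these tests use (dimensions and tolerances here are illustrative):

import torch
import bitsandbytes as bnb
from lion_pytorch import Lion

p1 = torch.randn(1024, 32, device="cuda", requires_grad=True)
p2 = p1.clone().detach().requires_grad_(True)
ref, tst = Lion([p1], lr=1e-4), bnb.optim.Lion([p2], lr=1e-4)

for _ in range(10):
    g = 0.01 * torch.randn_like(p1)
    p1.grad, p2.grad = g.clone(), g.clone()
    ref.step(); tst.step()
    # reference "exp_avg" buffer vs bitsandbytes "state1" buffer
    torch.testing.assert_close(ref.state[p1]["exp_avg"],
                               tst.state[p2]["state1"], atol=1e-4, rtol=1e-4)
    torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-4)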