Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
cb4c3c8c
Commit
cb4c3c8c
authored
Mar 09, 2023
by
Phil Wang
Browse files
do a bunch of typical bookkeeping before getting to main lion logic
parent
d43ea972
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
58 additions
and
4 deletions
+58
-4
README.md
README.md
+1
-1
bitsandbytes/functional.py
bitsandbytes/functional.py
+12
-0
bitsandbytes/optim/lion.py
bitsandbytes/optim/lion.py
+3
-3
csrc/kernels.cu
csrc/kernels.cu
+16
-0
csrc/ops.cu
csrc/ops.cu
+9
-0
csrc/ops.cuh
csrc/ops.cuh
+1
-0
csrc/pythonInterface.c
csrc/pythonInterface.c
+12
-0
tests/test_optim.py
tests/test_optim.py
+4
-0
No files found.
README.md
View file @
cb4c3c8c
...
...
@@ -40,7 +40,7 @@ out = linear(x.to(torch.float16))
## Features
-
8-bit Matrix multiplication with mixed precision decomposition
-
LLM.int8() inference
-
8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory)
-
8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB
, Lion
(saves 75% memory)
-
Stable Embedding Layer: Improved stability through better initialization, and normalization
-
8-bit quantization: Quantile, Linear, and Dynamic quantization
-
Fast quantile estimation: Up to 100x faster than other algorithms
...
...
bitsandbytes/functional.py
View file @
cb4c3c8c
...
...
@@ -35,6 +35,10 @@ if COMPILED_WITH_CUDA:
lib
.
crmsprop32bit_g32
,
lib
.
crmsprop32bit_g16
,
)
str2optimizer32bit
[
"lion"
]
=
(
lib
.
clion32bit_g32
,
lib
.
clion32bit_g16
,
)
str2optimizer32bit
[
"adagrad"
]
=
(
lib
.
cadagrad32bit_g32
,
lib
.
cadagrad32bit_g16
,
...
...
@@ -58,6 +62,10 @@ if COMPILED_WITH_CUDA:
lib
.
crmsprop_static_8bit_g32
,
lib
.
crmsprop_static_8bit_g16
,
)
str2optimizer8bit
[
"lion"
]
=
(
lib
.
clion_static_8bit_g32
,
lib
.
clion_static_8bit_g16
,
)
str2optimizer8bit
[
"lamb"
]
=
(
lib
.
cadam_static_8bit_g32
,
lib
.
cadam_static_8bit_g16
,
...
...
@@ -80,6 +88,10 @@ if COMPILED_WITH_CUDA:
lib
.
crmsprop_8bit_blockwise_fp32
,
lib
.
crmsprop_8bit_blockwise_fp16
,
)
str2optimizer8bit_blockwise
[
"lion"
]
=
(
lib
.
clion_8bit_blockwise_fp32
,
lib
.
clion_8bit_blockwise_fp16
,
)
str2optimizer8bit_blockwise
[
"adagrad"
]
=
(
lib
.
cadagrad_8bit_blockwise_fp32
,
lib
.
cadagrad_8bit_blockwise_fp16
,
...
...
bitsandbytes/optim/lion.py
View file @
cb4c3c8c
...
...
@@ -19,7 +19,7 @@ class Lion(Optimizer1State):
block_wise
=
True
,
):
super
().
__init__
(
"
rmsprop
"
,
"
lion
"
,
params
,
lr
,
betas
,
...
...
@@ -46,7 +46,7 @@ class Lion8bit(Optimizer1State):
block_wise
=
True
,
):
super
().
__init__
(
"
rmsprop
"
,
"
lion
"
,
params
,
lr
,
betas
,
...
...
@@ -73,7 +73,7 @@ class Lion32bit(Optimizer1State):
block_wise
=
True
,
):
super
().
__init__
(
"
rmsprop
"
,
"
lion
"
,
params
,
lr
,
betas
,
...
...
csrc/kernels.cu
View file @
cb4c3c8c
...
...
@@ -790,6 +790,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
float
)
g_vals
[
j
]);
// state update
s1_vals
[
j
]
=
s1_vals
[
j
]
*
s1_vals
[
j
];
// update norm
break
;
case
LION
:
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]));
// state update
s1_vals
[
j
]
=
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
(
s1_vals
[
j
])
+
eps
);
// update value
...
...
@@ -890,6 +891,7 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
update_scale
*
(
-
lr
*
(
s1_vals
[
j
]));
break
;
case
LION
:
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
((
float
)
g_vals
[
j
])
*
((
float
)
g_vals
[
j
]));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
update_scale
*
(
lr
*
__fdividef
((
float
)
g_vals
[
j
],
sqrtf
((
float
)
s1_vals
[
j
])
+
eps
));
...
...
@@ -1219,6 +1221,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
if
(
unorm
!=
NULL
)
local_unorm
+=
s1_vals
[
j
]
*
s1_vals
[
j
];
break
;
case
LION
:
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
break
;
...
...
@@ -1321,6 +1324,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
+
(
-
lr
*
update_scale
*
(
s1_vals
[
j
]));
break
;
case
LION
:
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
(
lr
*
__fdividef
(
g_val
,
sqrtf
(
s1_vals
[
j
])
+
eps
));
...
...
@@ -1664,6 +1668,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
else
s1_vals
[
j
]
=
(
s1_vals
[
j
]
*
beta1
)
+
g_val
;
break
;
case
LION
:
case
RMSPROP
:
s1_vals
[
j
]
=
s1_vals
[
j
]
*
beta1
+
((
1.0
f
-
beta1
)
*
(
g_val
*
g_val
));
break
;
...
...
@@ -1701,6 +1706,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
case
MOMENTUM
:
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
s1_vals
[
j
]);
break
;
case
LION
:
case
RMSPROP
:
g_val
=
g_vals
[
j
];
p_vals
[
j
]
=
((
float
)
p_vals
[
j
])
-
lr
*
(
__fdividef
(
g_val
,
sqrtf
(
s1_vals
[
j
])
+
eps
));
...
...
@@ -2699,6 +2705,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State
(
MOMENTUM
,
float
)
MAKE_PreconditionOptimizer32bit1State
(
RMSPROP
,
half
)
MAKE_PreconditionOptimizer32bit1State
(
RMSPROP
,
float
)
MAKE_PreconditionOptimizer32bit1State
(
LION
,
half
)
MAKE_PreconditionOptimizer32bit1State
(
LION
,
float
)
MAKE_PreconditionOptimizer32bit1State
(
ADAGRAD
,
half
)
MAKE_PreconditionOptimizer32bit1State
(
ADAGRAD
,
float
)
...
...
@@ -2710,6 +2718,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State
(
MOMENTUM
,
float
)
MAKE_Optimizer32bit1State
(
RMSPROP
,
half
)
MAKE_Optimizer32bit1State
(
RMSPROP
,
float
)
MAKE_Optimizer32bit1State
(
LION
,
half
)
MAKE_Optimizer32bit1State
(
LION
,
float
)
MAKE_Optimizer32bit1State
(
ADAGRAD
,
half
)
MAKE_Optimizer32bit1State
(
ADAGRAD
,
float
)
...
...
@@ -2742,6 +2752,8 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State
(
MOMENTUM
,
float
)
MAKE_PreconditionStatic8bit1State
(
RMSPROP
,
half
)
MAKE_PreconditionStatic8bit1State
(
RMSPROP
,
float
)
MAKE_PreconditionStatic8bit1State
(
LION
,
half
)
MAKE_PreconditionStatic8bit1State
(
LION
,
float
)
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
...
...
@@ -2758,6 +2770,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State
(
MOMENTUM
,
float
)
MAKE_optimizerStatic8bit1State
(
RMSPROP
,
half
)
MAKE_optimizerStatic8bit1State
(
RMSPROP
,
float
)
MAKE_optimizerStatic8bit1State
(
LION
,
half
)
MAKE_optimizerStatic8bit1State
(
LION
,
float
)
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
...
...
@@ -2849,5 +2863,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise
(
MOMENTUM
,
half
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
RMSPROP
,
float
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
RMSPROP
,
half
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
LION
,
float
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
LION
,
half
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
ADAGRAD
,
float
,
2048
,
8
)
MAKE_OptimizerStatic8bit1StateBlockwise
(
ADAGRAD
,
half
,
2048
,
8
)
csrc/ops.cu
View file @
cb4c3c8c
...
...
@@ -120,6 +120,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
case
MOMENTUM
:
case
RMSPROP
:
case
ADAGRAD
:
case
LION
:
if
(
max_unorm
>
0.0
f
)
{
...
...
@@ -163,6 +164,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case
MOMENTUM
:
case
RMSPROP
:
case
ADAGRAD
:
case
LION
:
CUDA_CHECK_RETURN
(
cudaMemset
(
new_max1
,
0
,
1
*
sizeof
(
float
)));
kPreconditionOptimizerStatic8bit1State
<
T
,
OPTIMIZER
><<<
num_blocks
,
256
>>>
(
p
,
g
,
state1
,
unorm
,
beta1
,
eps
,
step
,
quantiles1
,
max1
,
new_max1
,
weight_decay
,
gnorm_scale
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
...
...
@@ -198,6 +200,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
case
MOMENTUM
:
case
RMSPROP
:
case
ADAGRAD
:
case
LION
:
num_blocks
=
n
/
BLOCKSIZE_1STATE
;
num_blocks
=
n
%
BLOCKSIZE_1STATE
==
0
?
num_blocks
:
num_blocks
+
1
;
kOptimizerStatic8bit1StateBlockwise
<
T
,
OPTIMIZER
,
BLOCKSIZE_1STATE
,
NUM_1STATE
><<<
num_blocks
,
BLOCKSIZE_1STATE
/
NUM_1STATE
>>>
(
p
,
g
,
state1
,
beta1
,
beta2
,
eps
,
step
,
lr
,
...
...
@@ -707,6 +710,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit
(
MOMENTUM
,
float
)
MAKE_optimizer32bit
(
RMSPROP
,
half
)
MAKE_optimizer32bit
(
RMSPROP
,
float
)
MAKE_optimizer32bit
(
LION
,
half
)
MAKE_optimizer32bit
(
LION
,
float
)
MAKE_optimizer32bit
(
ADAGRAD
,
half
)
MAKE_optimizer32bit
(
ADAGRAD
,
float
)
...
...
@@ -726,6 +731,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit
(
MOMENTUM
,
float
)
MAKE_optimizerStatic8bit
(
RMSPROP
,
half
)
MAKE_optimizerStatic8bit
(
RMSPROP
,
float
)
MAKE_optimizerStatic8bit
(
LION
,
half
)
MAKE_optimizerStatic8bit
(
LION
,
float
)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
...
...
@@ -738,6 +745,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise
(
float
,
MOMENTUM
);
MAKE_optimizerStatic8bitBlockwise
(
half
,
RMSPROP
);
MAKE_optimizerStatic8bitBlockwise
(
float
,
RMSPROP
);
MAKE_optimizerStatic8bitBlockwise
(
half
,
LION
);
MAKE_optimizerStatic8bitBlockwise
(
float
,
LION
);
MAKE_optimizerStatic8bitBlockwise
(
half
,
ADAGRAD
);
MAKE_optimizerStatic8bitBlockwise
(
float
,
ADAGRAD
);
...
...
csrc/ops.cuh
View file @
cb4c3c8c
...
...
@@ -70,6 +70,7 @@ typedef enum Optimizer_t
RMSPROP
=
2
,
LARS
=
3
,
ADAGRAD
=
4
,
LION
=
5
,
}
Optimizer_t
;
typedef
enum
Transform_t
...
...
csrc/pythonInterface.c
View file @
cb4c3c8c
...
...
@@ -33,6 +33,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
MAKE_FUNC32
(
adam
,
ADAM
,
half
,
16
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_FUNC32
(
lion
,
LION
,
float
,
32
)
MAKE_FUNC32
(
lion
,
LION
,
half
,
16
)
MAKE_FUNC32
(
adagrad
,
ADAGRAD
,
float
,
32
)
MAKE_FUNC32
(
adagrad
,
ADAGRAD
,
half
,
16
)
...
...
@@ -55,6 +57,8 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32)
MAKE_FUNC8
(
momentum
,
MOMENTUM
,
half
,
16
)
MAKE_FUNC8
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_FUNC8
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_FUNC8
(
lion
,
LION
,
float
,
32
)
MAKE_FUNC8
(
lion
,
LION
,
half
,
16
)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
...
...
@@ -68,6 +72,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
MAKE_BLOCKWISE8
(
momentum
,
MOMENTUM
,
float
,
32
)
MAKE_BLOCKWISE8
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_BLOCKWISE8
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_BLOCKWISE8
(
lion
,
LION
,
half
,
16
)
MAKE_BLOCKWISE8
(
lion
,
LION
,
float
,
32
)
MAKE_BLOCKWISE8
(
adagrad
,
ADAGRAD
,
half
,
16
)
MAKE_BLOCKWISE8
(
adagrad
,
ADAGRAD
,
float
,
32
)
...
...
@@ -161,6 +167,8 @@ extern "C"
MAKE_CFUNC32
(
momentum
,
half
,
16
)
MAKE_CFUNC32
(
rmsprop
,
float
,
32
)
MAKE_CFUNC32
(
rmsprop
,
half
,
16
)
MAKE_CFUNC32
(
lion
,
float
,
32
)
MAKE_CFUNC32
(
lion
,
half
,
16
)
MAKE_CFUNC32
(
adagrad
,
float
,
32
)
MAKE_CFUNC32
(
adagrad
,
half
,
16
)
...
...
@@ -183,6 +191,8 @@ extern "C"
MAKE_CFUNC8
(
momentum
,
half
,
16
)
MAKE_CFUNC8
(
rmsprop
,
float
,
32
)
MAKE_CFUNC8
(
rmsprop
,
half
,
16
)
MAKE_CFUNC8
(
lion
,
float
,
32
)
MAKE_CFUNC8
(
lion
,
half
,
16
)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
...
...
@@ -196,6 +206,8 @@ extern "C"
MAKE_CBLOCKWISE8
(
momentum
,
MOMENTUM
,
float
,
32
)
MAKE_CBLOCKWISE8
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_CBLOCKWISE8
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_CBLOCKWISE8
(
lion
,
LION
,
half
,
16
)
MAKE_CBLOCKWISE8
(
lion
,
LION
,
float
,
32
)
MAKE_CBLOCKWISE8
(
adagrad
,
ADAGRAD
,
half
,
16
)
MAKE_CBLOCKWISE8
(
adagrad
,
ADAGRAD
,
float
,
32
)
...
...
tests/test_optim.py
View file @
cb4c3c8c
...
...
@@ -50,6 +50,10 @@ str2optimizers["rmsprop"] = (
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
)
str2optimizers
[
"rmsprop"
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
)
str2optimizers
[
"adam8bit"
]
=
(
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
False
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment