OpenDAS / bitsandbytes

Commit 1b8772a8, authored May 23, 2023 by Tim Dettmers
Added PagedLion and bf16 Lion.
Parent: 2bce175d

Showing 7 changed files with 46 additions and 97 deletions (+46 −97):
bitsandbytes/functional.py        +2  −4
bitsandbytes/optim/__init__.py    +1  −1
bitsandbytes/optim/lion.py        +19 −76
csrc/kernels.cu                   +3  −0
csrc/ops.cu                       +2  −0
csrc/pythonInterface.c            +8  −4
tests/test_optim.py               +11 −12
bitsandbytes/functional.py (view file @ 1b8772a8)

```diff
@@ -37,10 +37,7 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop32bit_grad_32,
         lib.crmsprop32bit_grad_16,
     )
-    str2optimizer32bit["lion"] = (
-        lib.clion32bit_grad_32,
-        lib.clion32bit_grad_16,
-    )
+    str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
     str2optimizer32bit["adagrad"] = (
         lib.cadagrad32bit_grad_32,
         lib.cadagrad32bit_grad_16,
@@ -89,6 +86,7 @@ if COMPILED_WITH_CUDA:
     str2optimizer8bit_blockwise["lion"] = (
         lib.clion_8bit_blockwise_grad_fp32,
         lib.clion_8bit_blockwise_grad_fp16,
+        lib.clion_8bit_blockwise_grad_bf16,
     )
     str2optimizer8bit_blockwise["adagrad"] = (
         lib.cadagrad_8bit_blockwise_grad_fp32,
```
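The `"lion"` registry entry switches from two pointers named by bit width (`_32`, `_16`) to three named by dtype (`_fp32`, `_fp16`, `_bf16`). A sketch of the indexing convention this implies; the real dispatch lives elsewhere in `bitsandbytes/functional.py` (around `optimizer_update_32bit`), so the helper below is illustrative only:

```python
import torch

# Illustrative only: map a gradient dtype onto the registry tuple laid out in
# the diff above, i.e. (fp32_fn, fp16_fn, bf16_fn).
def pick_32bit_kernel(registry, name, grad_dtype):
    fns = registry[name]
    if grad_dtype == torch.float32:
        return fns[0]                     # e.g. lib.clion32bit_grad_fp32
    if grad_dtype == torch.float16:
        return fns[1]                     # e.g. lib.clion32bit_grad_fp16
    if grad_dtype == torch.bfloat16 and len(fns) == 3:
        return fns[2]                     # e.g. lib.clion32bit_grad_bf16 (new in this commit)
    raise ValueError(f"no 32-bit {name} kernel for {grad_dtype}")
```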
bitsandbytes/optim/__init__.py (view file @ 1b8772a8)

```diff
@@ -12,5 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .optimizer import GlobalOptimManager
 from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
-from .lion import Lion, Lion8bit, Lion32bit
+from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
 from .sgd import SGD, SGD8bit, SGD32bit
```
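With the widened re-export, the paged classes join the public `bnb.optim` namespace. A quick smoke check, assuming a build of this commit is installed:

```python
import bitsandbytes as bnb
from bitsandbytes.optim import PagedLion, PagedLion8bit, PagedLion32bit

# Both import paths resolve to the same class object.
assert bnb.optim.PagedLion is PagedLion
```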
bitsandbytes/optim/lion.py (view file @ 1b8772a8)

```diff
@@ -4,84 +4,27 @@
 # LICENSE file in the root directory of this source tree.
 from bitsandbytes.optim.optimizer import Optimizer1State


 class Lion(Optimizer1State):
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
-        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)


 class Lion8bit(Optimizer1State):
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
-        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)


 class Lion32bit(Optimizer1State):
-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
-        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)


+class PagedLion(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedLion8bit(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedLion32bit(Optimizer1State):
+    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
```
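All six classes now funnel into `Optimizer1State` with an `is_paged` flag, and each `Paged*` variant is just its non-paged twin with `is_paged=True` hard-wired. A minimal usage sketch, assuming a CUDA build of this commit; the toy model is illustrative:

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)

# Two equivalent spellings of a paged 32-bit Lion:
opt = bnb.optim.PagedLion(model.parameters(), lr=1e-4, betas=(0.9, 0.99))
# opt = bnb.optim.Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), is_paged=True)

out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
out.sum().backward()
opt.step()        # bf16 gradients take the clion32bit_grad_bf16 path added above
opt.zero_grad()
```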
csrc/kernels.cu (view file @ 1b8772a8)

```diff
@@ -3666,6 +3666,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
 MAKE_PreconditionOptimizer32bit1State(LION, half)
 MAKE_PreconditionOptimizer32bit1State(LION, float)
+MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
@@ -3679,6 +3680,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
 MAKE_Optimizer32bit1State(LION, half)
 MAKE_Optimizer32bit1State(LION, float)
+MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
 MAKE_Optimizer32bit1State(ADAGRAD, half)
 MAKE_Optimizer32bit1State(ADAGRAD, float)
@@ -3852,5 +3854,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
```
csrc/ops.cu (view file @ 1b8772a8)

```diff
@@ -802,6 +802,7 @@ MAKE_optimizer32bit(RMSPROP, half)
 MAKE_optimizer32bit(RMSPROP, float)
 MAKE_optimizer32bit(LION, half)
 MAKE_optimizer32bit(LION, float)
+MAKE_optimizer32bit(LION, __nv_bfloat16)
 MAKE_optimizer32bit(ADAGRAD, half)
 MAKE_optimizer32bit(ADAGRAD, float)
@@ -837,6 +838,7 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(half, LION);
 MAKE_optimizerStatic8bitBlockwise(float, LION);
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
 MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
 MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
```
csrc/pythonInterface.c (view file @ 1b8772a8)

```diff
@@ -51,8 +51,9 @@ MAKE_FUNC32(adam, ADAM, half, fp16)
 MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
-MAKE_FUNC32(lion, LION, float, 32)
-MAKE_FUNC32(lion, LION, half, 16)
+MAKE_FUNC32(lion, LION, float, fp32)
+MAKE_FUNC32(lion, LION, half, fp16)
+MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
 MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
 MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
@@ -95,6 +96,7 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
 MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_BLOCKWISE8(lion, LION, half, fp16)
 MAKE_BLOCKWISE8(lion, LION, float, fp32)
+MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)

 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
@@ -201,8 +203,9 @@ extern "C"
   MAKE_CFUNC32(momentum, half, 16)
   MAKE_CFUNC32(rmsprop, float, 32)
   MAKE_CFUNC32(rmsprop, half, 16)
-  MAKE_CFUNC32(lion, float, 32)
-  MAKE_CFUNC32(lion, half, 16)
+  MAKE_CFUNC32(lion, float, fp32)
+  MAKE_CFUNC32(lion, half, fp16)
+  MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
   MAKE_CFUNC32(adagrad, float, 32)
   MAKE_CFUNC32(adagrad, half, 16)
@@ -245,6 +248,7 @@ extern "C"
   MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
   MAKE_CBLOCKWISE8(lion, LION, half, fp16)
   MAKE_CBLOCKWISE8(lion, LION, float, fp32)
+  MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)

   void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
   void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
```
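Alongside the new `bf16` exports, the 32-bit Lion symbols are renamed from bit-width suffixes to dtype suffixes, matching the registry update in `bitsandbytes/functional.py` above. A hedged sanity probe from Python; it assumes the shared-library handle is reachable as `F.lib`, which is how `bitsandbytes.functional` holds it in this tree:

```python
import bitsandbytes.functional as F

# ctypes resolves symbols lazily, so hasattr() doubles as a presence probe.
for sym in ("clion32bit_grad_fp32", "clion32bit_grad_fp16",
            "clion32bit_grad_bf16", "clion_8bit_blockwise_grad_bf16"):
    print(sym, "->", hasattr(F.lib, sym))
```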
tests/test_optim.py (view file @ 1b8772a8)

```diff
@@ -19,11 +19,11 @@ import bitsandbytes.functional as F
 k = 20


 def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
-    idx = torch.isclose(a, b, rtol, atol)
+    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
     error_count = (idx == 0).sum().item()
     if error_count > max_error_count:
         print(f"Too many values not close: assert {error_count} < {max_error_count}")
-        torch.testing.assert_close(a, b, rtol, atol)
+        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


 def get_temp_dir():
```
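The helper change is a real fix, not cosmetics: `torch.testing.assert_close` takes `rtol`/`atol` as keyword-only arguments, so the old positional call raised a `TypeError` before the tolerance comparison could even run (`torch.isclose` accepts them positionally, so that edit is for consistency). A small repro of the distinction:

```python
import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.5])

try:
    torch.testing.assert_close(a, b, 1e-3, 1e-3)            # positional rtol/atol
except TypeError as e:
    print("rejected before comparing:", e)

try:
    torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3)  # keyword form
except AssertionError:
    print("comparison ran and reported the real mismatch")
```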
```diff
@@ -35,13 +35,8 @@ def get_temp_dir():
 def rm_path(path):
     shutil.rmtree(path)

-str2bf16support = {}
-str2bf16support['adam8bit_blockwise'] = True
-
 str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
 # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
 # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
 str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
 str2optimizers["momentum_pytorch"] = (
@@ -51,8 +46,8 @@ str2optimizers["momentum_pytorch"] = (
 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
 str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
 str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
 # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
 str2optimizers["lion"] = (Lion, bnb.optim.Lion)
+str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
```
```diff
@@ -76,6 +71,7 @@ str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.
 str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
 str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
 str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
+str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@@ -90,6 +86,7 @@ str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["lion"] = [("exp_avg", "state1")]
+str2statenames["paged_lion"] = [("exp_avg", "state1")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["rmsprop"] = [("square_avg", "state1")]
@@ -104,15 +101,17 @@ str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1"
 str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
 str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
+str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
-gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion']
+gtype = [torch.float32, torch.float16, torch.bfloat16]
+optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
 ]

 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer32bit(dim1, dim2, gtype, optim_name):
+    if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
```
```diff
@@ -254,7 +253,7 @@ names = [
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
-    if gtype == torch.bfloat16 and optim_name not in str2bf16support: return
+    if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
@@ -485,7 +484,7 @@ gtype = [torch.float32, torch.float16]
 # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
 # optimizer_names = ['lamb_apex', 'lamb8bit']
 # optimizer_names = ['lars_apex', 'lars8bit']
-optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise']
+optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
```
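For scale: the 8-bit blockwise test matrix is the Cartesian product of the dimension, dtype, and optimizer lists, so adding `paged_lion8bit_blockwise` adds one optimizer across every dim/dtype combination. A quick count; the `dim1`/`dim2` values below are placeholders, since only `gtype` and `optimizer_names` are visible in this hunk:

```python
import torch
from itertools import product

dim1 = [1024]                 # placeholder
dim2 = [32, 1024, 4097]       # placeholder
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise',
                   'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']

values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
print(len(values))  # 1 * 3 * 2 * 4 = 24 parametrized cases
print(names[-1])    # dim1_1024_dim2_4097_gtype_torch.float16_optim_paged_lion8bit_blockwise
```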