OpenDAS / bitsandbytes · Commit 7dc198fe
Authored Apr 17, 2023 by Tim Dettmers
Parent: b8ea2b41

Added 32-bit optimizer for bfloat16 gradients.

Showing 7 changed files with 65 additions and 86 deletions (+65, -86).
Changed files:
  bitsandbytes/cextension.py     +1   -1
  bitsandbytes/functional.py     +34  -55
  bitsandbytes/nn/modules.py     +7   -0
  csrc/kernels.cu                +5   -2
  csrc/ops.cu                    +1   -0
  csrc/pythonInterface.c         +6   -4
  tests/test_optim.py            +11  -24
bitsandbytes/cextension.py
@@ -23,7 +23,7 @@ try:
         CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
         If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
         https://github.com/TimDettmers/bitsandbytes/issues'''
         )
-    lib.cadam32bit_g32
+    lib.cadam_8bit_blockwise_fp32
     lib.get_context.restype = ct.c_void_p
     lib.get_cusparse.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
bitsandbytes/functional.py
@@ -28,7 +28,7 @@ name2qmap = {}
 if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {}
-    str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
+    str2optimizer32bit["adam"] = (lib.cadam32bit_gfp32, lib.cadam32bit_gfp16, lib.cadam32bit_gbf16)
     str2optimizer32bit["momentum"] = (
         lib.cmomentum32bit_g32,
         lib.cmomentum32bit_g16,

@@ -41,11 +41,6 @@ if COMPILED_WITH_CUDA:
         lib.cadagrad32bit_g32,
         lib.cadagrad32bit_g16,
     )
-    str2optimizer32bit["lars"] = (
-        lib.cmomentum32bit_g32,
-        lib.cmomentum32bit_g16,
-    )
-    str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)

     str2optimizer8bit = {}
     str2optimizer8bit["adam"] = (
@@ -998,53 +993,37 @@ def optimizer_update_32bit(
     if max_unorm > 0.0:
         param_norm = torch.norm(p.data.float())

-    if optimizer_name not in str2optimizer32bit:
-        raise NotImplementedError(
-            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
-        )
-
-    if g.dtype == torch.float32 and state1.dtype == torch.float32:
-        str2optimizer32bit[optimizer_name][0](
-            get_ptr(g),
-            get_ptr(p),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_float(weight_decay),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
-    elif g.dtype == torch.float16 and state1.dtype == torch.float32:
-        str2optimizer32bit[optimizer_name][1](
-            get_ptr(g),
-            get_ptr(p),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_float(weight_decay),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+    optim_func = None
+    if g.dtype == torch.float32:
+        optim_func = str2optimizer32bit[optimizer_name][0]
+    elif g.dtype == torch.float16:
+        optim_func = str2optimizer32bit[optimizer_name][1]
+    elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3):
+        optim_func = str2optimizer32bit[optimizer_name][2]
     else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
-        )
+        raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")
+
+    is_on_gpu([g, p, state1, state2, unorm_vec])
+    prev_device = pre_call(g.device)
+    optim_func(
+        get_ptr(g),
+        get_ptr(p),
+        get_ptr(state1),
+        get_ptr(state2),
+        get_ptr(unorm_vec),
+        ct.c_float(max_unorm),
+        ct.c_float(param_norm),
+        ct.c_float(beta1),
+        ct.c_float(beta2),
+        ct.c_float(eps),
+        ct.c_float(weight_decay),
+        ct.c_int32(step),
+        ct.c_float(lr),
+        ct.c_float(gnorm_scale),
+        ct.c_bool(skip_zeros),
+        ct.c_int32(g.numel()))
+    post_call(prev_device)


 def optimizer_update_8bit(
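The refactor above collapses two near-identical ctypes call sites into a single dispatch step: the gradient dtype selects an index into the `str2optimizer32bit` entry, and the third slot, present only for optimizers that ship a bfloat16 kernel, serves `torch.bfloat16`. A minimal sketch of that selection rule, with a hypothetical string-based registry entry standing in for the real ctypes function pointers:

```python
import torch

def pick_impl(registry_entry, grad_dtype):
    # Mirrors the dtype -> index rule used by optimizer_update_32bit (sketch only).
    if grad_dtype == torch.float32:
        return registry_entry[0]          # e.g. cadam32bit_gfp32
    if grad_dtype == torch.float16:
        return registry_entry[1]          # e.g. cadam32bit_gfp16
    if grad_dtype == torch.bfloat16 and len(registry_entry) == 3:
        return registry_entry[2]          # e.g. cadam32bit_gbf16
    raise ValueError(f"Gradient dtype not supported: {grad_dtype}")

# A 3-tuple entry means a bf16 kernel is available; a 2-tuple entry would raise for bf16.
adam_entry = ("cadam32bit_gfp32", "cadam32bit_gfp16", "cadam32bit_gbf16")
print(pick_impl(adam_entry, torch.bfloat16))  # -> "cadam32bit_gbf16"
```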
@@ -1199,12 +1178,12 @@ def optimizer_update_8bit_blockwise(
     optim_func = None
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0]
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1]
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
     elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and
           len(str2optimizer8bit_blockwise[optimizer_name])==3):
-        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2]
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
     else:
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"

@@ -1213,7 +1192,7 @@ def optimizer_update_8bit_blockwise(
     is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
     prev_device = pre_call(g.device)
-    optimizer_func(
+    optim_func(
         get_ptr(p),
         get_ptr(g),
         get_ptr(state1),
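Taken together, these functional.py changes let the plain 32-bit bitsandbytes optimizers step on bfloat16 parameters and gradients. A usage sketch (not part of the diff; it assumes a working CUDA build and uses an illustrative single-layer model):

```python
import torch
import bitsandbytes as bnb

# bf16 weights produce bf16 gradients; optimizer states stay in fp32 (32-bit optimizer).
model = torch.nn.Linear(1024, 1024).cuda().to(torch.bfloat16)
optimizer = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=32)

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().sum()
loss.backward()
optimizer.step()      # routes through str2optimizer32bit["adam"][2] (cadam32bit_gbf16)
optimizer.zero_grad()
```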
bitsandbytes/nn/modules.py
@@ -178,6 +178,13 @@ class Params4bit(torch.nn.Parameter):
             s[0] = s[0].to(device)
             if self.compress_statistics:
                 # TODO: refactor this. This is a nightmare
+                # for 4-bit:
+                # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
+                # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
+                #s[-2][0] = s[-2][0].to(device) # offset
+                #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
+                # for 8-bit
                 s[-2][0] = s[-2][0].to(device) # offset
                 s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics
+                s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook
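For context, the Params4bit change above is exercised whenever a 4-bit layer is moved between devices, because the nested ("compressed") quantization statistics have to travel with the weight. A rough sketch of triggering that path; the 4-bit API was still in flux at this commit, so treat the constructor arguments below as assumptions based on later releases:

```python
import torch
import bitsandbytes as bnb

# compress_statistics=True turns on the nested quantization state the new lines move:
# the offset, the nested absmax, and the nested codebook.
layer = bnb.nn.Linear4bit(128, 128, compute_dtype=torch.float16, compress_statistics=True)
layer = layer.cuda()  # quantizes the weight and carries the nested statistics to the GPU
out = layer(torch.randn(4, 128, device="cuda", dtype=torch.float16))
print(out.shape)
```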
csrc/kernels.cu
@@ -2981,12 +2981,15 @@ template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8
                const float beta1, const float beta2, const float eps, const float weight_decay, \
                const int step, const float lr, const float gnorm_scale, const int n); \

-MAKE_PreconditionOptimizer32bit2State(ADAM, half)
 MAKE_PreconditionOptimizer32bit2State(ADAM, float)
+MAKE_PreconditionOptimizer32bit2State(ADAM, half)
+MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16)

 template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
+template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

 #define MAKE_PreconditionStatic8bit1State(oname, gtype) \
csrc/ops.cu
@@ -703,6 +703,7 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \

 MAKE_optimizer32bit(ADAM, half)
 MAKE_optimizer32bit(ADAM, float)
+MAKE_optimizer32bit(ADAM, __nv_bfloat16)
 MAKE_optimizer32bit(MOMENTUM, half)
 MAKE_optimizer32bit(MOMENTUM, float)
 MAKE_optimizer32bit(RMSPROP, half)
csrc/pythonInterface.c
@@ -29,8 +29,9 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \

 MAKE_FUNC32(momentum, MOMENTUM, float, 32)
 MAKE_FUNC32(momentum, MOMENTUM, half, 16)
-MAKE_FUNC32(adam, ADAM, float, 32)
-MAKE_FUNC32(adam, ADAM, half, 16)
+MAKE_FUNC32(adam, ADAM, float, fp32)
+MAKE_FUNC32(adam, ADAM, half, fp16)
+MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
 MAKE_FUNC32(adagrad, ADAGRAD, float, 32)

@@ -173,8 +174,9 @@ extern "C"
              const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
    { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \

-   MAKE_CFUNC32(adam, float, 32)
-   MAKE_CFUNC32(adam, half, 16)
+   MAKE_CFUNC32(adam, float, fp32)
+   MAKE_CFUNC32(adam, half, fp16)
+   MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
    MAKE_CFUNC32(momentum, float, 32)
    MAKE_CFUNC32(momentum, half, 16)
    MAKE_CFUNC32(rmsprop, float, 32)
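The net effect of the two pythonInterface.c hunks is that the exported 32-bit Adam entry points are now suffixed by gradient dtype (gfp32/gfp16/gbf16) rather than bit width (g32/g16), matching the new registrations in functional.py. A quick way to see which symbols a local build exposes (a sketch; it assumes the CUDA library loaded and that bitsandbytes.cextension still exports the ctypes handle as `lib`):

```python
from bitsandbytes.cextension import lib  # assumption: lib is the loaded ctypes CDLL

for name in ("cadam32bit_g32", "cadam32bit_gfp32", "cadam32bit_gfp16", "cadam32bit_gbf16"):
    # hasattr on a CDLL performs a symbol lookup, so missing exports simply report False.
    print(f"{name}: {'present' if hasattr(lib, name) else 'missing'}")
```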
tests/test_optim.py
@@ -44,10 +44,6 @@ str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
 )
-str2optimizers["lars"] = (
-    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
-)
 str2optimizers["rmsprop"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
@@ -64,10 +60,6 @@ str2optimizers["rmsprop8bit"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
 )
-str2optimizers["lars8bit"] = (
-    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
-)
 str2optimizers["adam8bit_blockwise"] = (
     torch.optim.Adam,
@@ -85,7 +77,6 @@ str2optimizers["rmsprop8bit_blockwise"] = (
 str2statenames = {}
 str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
-str2statenames["lars"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["rmsprop"] = [("square_avg", "state1")]
 str2statenames["adam8bit"] = [
@@ -106,7 +97,6 @@ str2statenames["momentum8bit"] = [
 str2statenames["momentum8bit_blockwise"] = [
     ("momentum_buffer", "state1", "qmap1", "absmax1")
 ]
-str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit_blockwise"] = [
     ("square_avg", "state1", "qmap1", "absmax1")
@@ -114,14 +104,10 @@ str2statenames["rmsprop8bit_blockwise"] = [
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
-gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
+gtype = [torch.float32, torch.float16, torch.bfloat16]
+optimizer_names = ["adam", "momentum", "rmsprop"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
-names = [
-    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
-]
+names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]


 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer32bit(dim1, dim2, gtype, optim_name):
     if dim1 == 1 and dim2 == 1:
@@ -135,6 +121,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
     if gtype == torch.float32:
         atol, rtol = 1e-6, 1e-5
+    elif gtype == torch.bfloat16:
+        atol, rtol = 1e-3, 1e-2
     else:
         atol, rtol = 1e-4, 1e-3
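The looser bfloat16 tolerances added above track the precision of the format itself: bfloat16 keeps only 8 significand bits, so its machine epsilon (2^-7) is eight times coarser than fp16's (2^-10) and far coarser than fp32's (2^-23). A quick check:

```python
import torch

print(torch.finfo(torch.bfloat16).eps)  # 0.0078125     (2**-7)
print(torch.finfo(torch.float16).eps)   # 0.0009765625  (2**-10)
print(torch.finfo(torch.float32).eps)   # ~1.19e-07     (2**-23)
```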
@@ -173,14 +161,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
                 rtol=rtol,
             )

-        if gtype == torch.float16:
+        if gtype != torch.float32:
             # the adam buffers should also be close because they are 32-bit
             # but the paramters can diverge because they are 16-bit
             # the difference grow larger and larger with each update
             # --> copy the state to keep weights close
-            p1.data = p1.data.half().float()
+            p1.data = p1.data.to(p2.dtype).float()
             p2.copy_(p1.data)
-            torch.testing.assert_allclose(p1.half(), p2)
+            torch.testing.assert_allclose(p1.to(p2.dtype), p2)

         if optim_name in ["lars", "lamb"]:
             assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
@@ -246,7 +234,6 @@ optimizer_names = [
     "momentum8bit",
     "rmsprop8bit",
     "adam8bit_blockwise",
-    "lars8bit",
     "momentum8bit_blockwise",
     "rmsprop8bit_blockwise",
 ]
@@ -321,10 +308,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
             relerr = err / torch.abs(p1)
             if g.dtype == torch.bfloat16:
                 assert err.mean() < 0.00015
-                assert relerr.mean() < 0.0015
+                assert relerr.mean() < 0.0016
             else:
-                assert err.mean() < 0.0001
-                assert relerr.mean() < 0.001
+                assert err.mean() < 0.00012
+                assert relerr.mean() < 0.0012

             errors.append(err.mean().item())
             relerrors.append(relerr.mean().item())
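For readers who want to reproduce what the 32-bit test now checks for bfloat16, here is a condensed, self-contained version of the same idea: step a reference torch.optim.Adam in fp32 and a bitsandbytes 32-bit Adam on a bf16 copy with identical gradients, then compare. This is an illustrative sketch, not test code from the repository:

```python
import torch
import bitsandbytes as bnb

torch.manual_seed(0)
p_ref = torch.randn(1024, 32, device="cuda", requires_grad=True)        # fp32 reference weights
p_bnb = p_ref.detach().clone().to(torch.bfloat16).requires_grad_(True)  # bf16 copy

opt_ref = torch.optim.Adam([p_ref], lr=1e-3)
opt_bnb = bnb.optim.Adam([p_bnb], lr=1e-3, optim_bits=32)  # bf16 grads, fp32 optimizer states

for _ in range(10):
    g = torch.randn_like(p_ref) * 0.01
    p_ref.grad = g
    p_bnb.grad = g.to(torch.bfloat16)
    opt_ref.step()
    opt_bnb.step()

# Remaining differences should be dominated by bf16 rounding of the weights themselves.
print((p_ref.detach() - p_bnb.detach().float()).abs().max())
```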