OpenDAS / bitsandbytes / Commits

Commit 3b89a05e, authored Jul 14, 2025 by Egor Krivov
Parent: abf4a1e3

Add 32bit optimizer interface

Showing 3 changed files with 158 additions and 69 deletions:

    bitsandbytes/_ops.py                 +43   -0
    bitsandbytes/backends/cuda/ops.py    +95   -0
    bitsandbytes/functional.py           +20  -69
bitsandbytes/_ops.py  (view file @ 3b89a05e)
...
...
@@ -350,6 +350,49 @@ if ipex_cpu or ipex_xpu:
        return torch.empty(shape, dtype=dtype, device=A.device)


torch.library.define(
    "bitsandbytes::optimizer_update_32bit",
    "(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, Tensor! unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()",
)


@register_fake("bitsandbytes::optimizer_update_32bit")
def _(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    torch._check(
        g.numel() == p.numel(),
        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
    )
    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    torch._check(
        g.dtype in compute_dtypes,
        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
    )
    torch._check(
        g.dtype == p.dtype,
        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
    )


torch.library.define(
    "bitsandbytes::optimizer_update_8bit_blockwise",
    "(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor qmap1, Tensor! qmap2, Tensor absmax1, Tensor! absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
...
...
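The block above follows PyTorch's custom-operator pattern: torch.library.define declares the op schema, and the register_fake implementation only validates arguments and computes output metadata, so FakeTensor tracing and torch.compile can reason about the op without running a real kernel. Below is a minimal, self-contained sketch of the same pattern on a hypothetical mylib::scale_add op; the namespace, op name, and CPU kernel are illustrative and not part of this commit, and it assumes PyTorch 2.4+, where torch.library.register_fake is available.

import torch

# Declare the operator schema: name, argument types, and return type.
# "mylib::scale_add" is a hypothetical example op, not part of bitsandbytes.
torch.library.define("mylib::scale_add", "(Tensor a, Tensor b, float alpha) -> Tensor")


# Fake (meta) implementation: only checks inputs and produces output metadata
# (shape/dtype), so tracing can reason about the op without real data.
@torch.library.register_fake("mylib::scale_add")
def _(a: torch.Tensor, b: torch.Tensor, alpha: float) -> torch.Tensor:
    torch._check(a.shape == b.shape, lambda: f"shape mismatch: {a.shape} vs {b.shape}")
    return torch.empty_like(a)


# Real implementation for CPU tensors (a CUDA kernel would be registered the
# same way using "CUDA" as the device key).
@torch.library.impl("mylib::scale_add", "CPU")
def _(a: torch.Tensor, b: torch.Tensor, alpha: float) -> torch.Tensor:
    return a + alpha * b


if __name__ == "__main__":
    x = torch.randn(4)
    y = torch.randn(4)
    out = torch.ops.mylib.scale_add(x, y, 0.5)
    print(out.shape)  # torch.Size([4])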
bitsandbytes/backends/cuda/ops.py  (view file @ 3b89a05e)
...
...
@@ -540,6 +540,42 @@ def _gemv_4bit_impl(
        )


"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
    "adam": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "momentum": (
        lib.cmomentum32bit_grad_32,
        lib.cmomentum32bit_grad_16,
    ),
    "rmsprop": (
        lib.crmsprop32bit_grad_32,
        lib.crmsprop32bit_grad_16,
    ),
    "lion": (
        lib.clion32bit_grad_fp32,
        lib.clion32bit_grad_fp16,
        lib.clion32bit_grad_bf16,
    ),
    "adagrad": (
        lib.cadagrad32bit_grad_32,
        lib.cadagrad32bit_grad_16,
    ),
    "lamb": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "ademamix": (
        lib.cademamix32bit_grad_fp32,
        lib.cademamix32bit_grad_fp16,
        lib.cademamix32bit_grad_bf16,
    ),
}

str2optimizer8bit_blockwise = {
    "adam": (
        lib.cadam_8bit_blockwise_grad_fp32,
...
...
@@ -574,6 +610,65 @@ str2optimizer8bit_blockwise = {
}


def optimizer_update_32bit(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    optim_fns = str2optimizer32bit.get(optimizer_name, None)
    if optim_fns is None:
        raise ValueError(
            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
        )

    if g.dtype == torch.float32:
        optim_func = optim_fns[0]
    elif g.dtype == torch.float16:
        optim_func = optim_fns[1]
    elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
        optim_func = optim_fns[2]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
        )

    with _cuda_device_of(g):
        optim_func(
            get_ptr(g),
            get_ptr(p),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(beta3),
            ct.c_float(alpha),
            ct.c_float(eps),
            ct.c_float(weight_decay),
            ct.c_int32(step),
            ct.c_float(lr),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )


def _optimizer_update_8bit_blockwise_impl(
    optimizer_name: str,
    g: torch.Tensor,
...
...
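The CUDA implementation above crosses into the prebuilt C library through ctypes: get_ptr hands the kernel a raw device pointer for each tensor (a NULL pointer when an optional tensor is None), and Python scalars are wrapped as ct.c_float / ct.c_int32 / ct.c_bool to match the kernel's C signature. Below is a minimal sketch of that calling convention under stated assumptions: _get_ptr is a stand-in for the library's own helper, and fake_kernel is a placeholder rather than one of the lib.c*32bit_grad_* symbols.

import ctypes as ct
from typing import Optional

import torch


def _get_ptr(t: Optional[torch.Tensor]) -> Optional[ct.c_void_p]:
    # Stand-in for the library's pointer helper: None maps to a NULL pointer,
    # otherwise the tensor's raw data pointer is passed through.
    if t is None:
        return None
    return ct.c_void_p(t.data_ptr())


def fake_kernel(g_ptr, p_ptr, lr, step, n):
    # Illustrative placeholder with the same calling style as the C kernels:
    # raw pointers for tensors, ctypes scalars for hyperparameters.
    print(f"would update {n.value} elements with lr={lr.value} at step {step.value}")


g = torch.randn(1024)
p = torch.randn(1024)

fake_kernel(
    _get_ptr(g),
    _get_ptr(p),
    ct.c_float(1e-3),   # learning rate
    ct.c_int32(1),      # step counter
    ct.c_int32(g.numel()),
)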
bitsandbytes/functional.py  (view file @ 3b89a05e)
...
...
@@ -20,41 +20,6 @@ from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib
 name2qmap = {}

-"""C FUNCTIONS FOR OPTIMIZERS"""
-str2optimizer32bit = {
-    "adam": (
-        lib.cadam32bit_grad_fp32,
-        lib.cadam32bit_grad_fp16,
-        lib.cadam32bit_grad_bf16,
-    ),
-    "momentum": (
-        lib.cmomentum32bit_grad_32,
-        lib.cmomentum32bit_grad_16,
-    ),
-    "rmsprop": (
-        lib.crmsprop32bit_grad_32,
-        lib.crmsprop32bit_grad_16,
-    ),
-    "lion": (
-        lib.clion32bit_grad_fp32,
-        lib.clion32bit_grad_fp16,
-        lib.clion32bit_grad_bf16,
-    ),
-    "adagrad": (
-        lib.cadagrad32bit_grad_32,
-        lib.cadagrad32bit_grad_16,
-    ),
-    "lamb": (
-        lib.cadam32bit_grad_fp32,
-        lib.cadam32bit_grad_fp16,
-        lib.cadam32bit_grad_bf16,
-    ),
-    "ademamix": (
-        lib.cademamix32bit_grad_fp32,
-        lib.cademamix32bit_grad_fp16,
-        lib.cademamix32bit_grad_bf16,
-    ),
-}
-
 str2optimizer8bit = {
     "adam": (
         lib.cadam_static_8bit_grad_32,
...
...
@@ -1219,41 +1184,27 @@ def optimizer_update_32bit(
     if max_unorm > 0.0:
         param_norm = torch.norm(p.data.float())

-    optim_func = None
-    if g.dtype == torch.float32:
-        optim_func = str2optimizer32bit[optimizer_name][0]
-    elif g.dtype == torch.float16:
-        optim_func = str2optimizer32bit[optimizer_name][1]
-    elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3:
-        optim_func = str2optimizer32bit[optimizer_name][2]
-    else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
-        )
-
     is_on_gpu([g, p, state1, state2, unorm_vec])

-    with _cuda_device_of(g):
-        optim_func(
-            get_ptr(g),
-            get_ptr(p),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(beta3),
-            ct.c_float(alpha),
-            ct.c_float(eps),
-            ct.c_float(weight_decay),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+    torch.ops.bitsandbytes.optimizer_update_32bit(
+        optimizer_name,
+        g,
+        p,
+        state1,
+        state2,
+        unorm_vec,
+        max_unorm,
+        param_norm,
+        beta1,
+        beta2,
+        beta3,
+        alpha,
+        eps,
+        weight_decay,
+        step,
+        lr,
+        gnorm_scale,
+        skip_zeros,
+    )


 @deprecated(
...
...
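With this change, functional.optimizer_update_32bit keeps its public signature but delegates to torch.ops.bitsandbytes.optimizer_update_32bit, so the update goes through the PyTorch dispatcher like any other registered op and alternative backends can plug in their own kernels. Existing user code is unaffected; a typical path into this wrapper is a 32-bit bitsandbytes optimizer, sketched below under the assumption that a CUDA-enabled bitsandbytes build is installed (usage illustration only, not part of the commit).

import torch
import bitsandbytes as bnb

# Toy parameter on the GPU; with optim_bits=32 the optimizer state stays in
# full precision and the update runs through the fused 32-bit kernels.
p = torch.nn.Parameter(torch.randn(4096, 4096, device="cuda"))
opt = bnb.optim.Adam([p], lr=1e-3, optim_bits=32)

loss = (p ** 2).sum()
loss.backward()

# opt.step() ends up in functional.optimizer_update_32bit, which now forwards
# to torch.ops.bitsandbytes.optimizer_update_32bit.
opt.step()
opt.zero_grad()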