OpenDAS / bitsandbytes · Commits

Commit b43edf56, authored Jul 11, 2025 by Egor Krivov
Parent: adc7fda7

Add interface for 8bit optimizer
Showing 5 changed files with 224 additions and 74 deletions (+224, -74):
bitsandbytes/_ops.py               +61   -0
bitsandbytes/backends/cuda/ops.py  +130  -0
bitsandbytes/functional.py         +23   -72
bitsandbytes/optim/optimizer.py    +3    -2
bitsandbytes/utils.py              +7    -0
bitsandbytes/_ops.py
@@ -348,3 +348,64 @@ if ipex_cpu or ipex_xpu:
     ) -> torch.Tensor:
         torch._check_is_size(blocksize)
         return torch.empty(shape, dtype=dtype, device=A.device)
+
+
+torch.library.define(
+    "bitsandbytes::optimizer_update_8bit_blockwise",
+    "(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor qmap1, Tensor! qmap2, Tensor absmax1, Tensor! absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_8bit_blockwise")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    skip_zeros=False,
+) -> None:
+    torch._check(
+        g.numel() == p.numel(),
+        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    )
+    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+    torch._check(
+        g.dtype in compute_dtypes,
+        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    )
+    torch._check(
+        g.dtype == p.dtype,
+        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    )
+    torch._check(
+        state1.dtype == torch.uint8,
+        lambda: f"state1 must be uint8, got {state1.dtype}",
+    )
+    torch._check(
+        qmap1.dtype == absmax1.dtype == torch.float32,
+        lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    )
+    if state2 is not None:
+        torch._check(
+            state2.dtype == torch.uint8,
+            lambda: f"state2 must be uint8, got {state2.dtype}",
+        )
+        torch._check(
+            qmap2.dtype == absmax2.dtype == torch.float32,
+            lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+        )
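Editor's note: the fake (meta) registration above is what makes the new op traceable. Under FakeTensor propagation (e.g. torch.compile), only these shape and dtype checks run; the real CUDA kernel is only invoked at execution time. A minimal standalone sketch of the same define-plus-register_fake pattern, using a hypothetical mylib::scale op that is not part of bitsandbytes, and assuming PyTorch >= 2.4 for torch.library.register_fake:

    import torch
    from torch.library import register_fake

    # Hypothetical op for illustration only; bitsandbytes' real op is the
    # one defined above. The schema declares a pure op returning a Tensor.
    torch.library.define("mylib::scale", "(Tensor x, float alpha) -> Tensor")

    @register_fake("mylib::scale")
    def _(x: torch.Tensor, alpha: float) -> torch.Tensor:
        # Runs during tracing: validate inputs and return a correctly
        # shaped placeholder without touching real data.
        torch._check(x.dtype in (torch.float16, torch.bfloat16, torch.float32))
        return torch.empty_like(x)

    # A concrete CPU kernel so the op is also callable eagerly.
    @torch.library.impl("mylib::scale", "cpu")
    def _(x: torch.Tensor, alpha: float) -> torch.Tensor:
        return x * alpha

The optimizer op follows the same split: schema and checks live in _ops.py, while each backend registers its own concrete kernel (the CUDA one below).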
bitsandbytes/backends/cuda/ops.py
@@ -538,3 +538,133 @@ def _gemv_4bit_impl(
             ct.c_int32(blocksize),
             stream,
         )
+
+
+str2optimizer8bit_blockwise = {
+    "adam": (
+        lib.cadam_8bit_blockwise_grad_fp32,
+        lib.cadam_8bit_blockwise_grad_fp16,
+        lib.cadam_8bit_blockwise_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum_8bit_blockwise_grad_fp32,
+        lib.cmomentum_8bit_blockwise_grad_fp16,
+        lib.cmomentum_8bit_blockwise_grad_bf16,
+    ),
+    "rmsprop": (
+        lib.crmsprop_8bit_blockwise_grad_fp32,
+        lib.crmsprop_8bit_blockwise_grad_fp16,
+        lib.crmsprop_8bit_blockwise_grad_bf16,
+    ),
+    "lion": (
+        lib.clion_8bit_blockwise_grad_fp32,
+        lib.clion_8bit_blockwise_grad_fp16,
+        lib.clion_8bit_blockwise_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad_8bit_blockwise_grad_fp32,
+        lib.cadagrad_8bit_blockwise_grad_fp16,
+        lib.cadagrad_8bit_blockwise_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix_8bit_blockwise_grad_fp32,
+        lib.cademamix_8bit_blockwise_grad_fp16,
+        lib.cademamix_8bit_blockwise_grad_bf16,
+    ),
+}
+
+
+def _optimizer_update_8bit_blockwise_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    skip_zeros=False,
+) -> None:
+    # torch._check(
+    #     g.numel() == p.numel(),
+    #     lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    # )
+    # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+    # torch._check(
+    #     g.dtype in compute_dtypes,
+    #     lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    # )
+    # torch._check(
+    #     g.dtype == p.dtype,
+    #     lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    # )
+    # torch._check(
+    #     state1.dtype == torch.uint8,
+    #     lambda: f"state1 must be uint8, got {state1.dtype}",
+    # )
+    # torch._check(
+    #     qmap1.dtype == absmax1.dtype == torch.float32,
+    #     lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    # )
+    # if state2 is not None:
+    #     torch._check(
+    #         state2.dtype == torch.uint8,
+    #         lambda: f"state2 must be uint8, got {state2.dtype}",
+    #     )
+    #     torch._check(
+    #         qmap2.dtype == absmax2.dtype == torch.float32,
+    #         lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+    #     )
+    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
+    if optimizer_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+
+    if g.dtype == torch.float32:
+        optimizer_fn = optimizer_fns[0]
+    elif g.dtype == torch.float16:
+        optimizer_fn = optimizer_fns[1]
+    elif g.dtype == torch.bfloat16:
+        optimizer_fn = optimizer_fns[2]
+    else:
+        raise ValueError(
+            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
+        )
+
+    with _cuda_device_of(g):
+        optimizer_fn(
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            get_ptr(qmap1),
+            get_ptr(qmap2),
+            get_ptr(absmax1),
+            get_ptr(absmax2),
+            ct.c_float(weight_decay),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
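Editor's note: each table entry maps an optimizer name to its three C kernels in (fp32, fp16, bf16) order, and the implementation indexes that tuple by gradient dtype before handing raw pointers to the library. With the kernel registered under the "cuda" dispatch key, callers reach it through the dispatcher rather than the ctypes wrappers. A hedged sketch of a direct call; buffer shapes, the blocksize of 256 (hence 4 absmax blocks for 1024 elements), the stand-in quantization map, and all hyperparameters are assumptions for illustration, not values from this commit:

    import torch
    import bitsandbytes  # importing registers the bitsandbytes:: ops

    n = 1024
    p = torch.randn(n, device="cuda")                              # parameters
    g = torch.randn(n, device="cuda")                              # gradients
    state1 = torch.zeros(n, dtype=torch.uint8, device="cuda")      # 8-bit state
    qmap1 = torch.linspace(-1.0, 1.0, 256, device="cuda")          # stand-in for the real quantization map
    absmax1 = torch.ones(n // 256, dtype=torch.float32, device="cuda")

    torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
        "momentum", g, p, state1, None,
        0.9, 0.0, 0.0, 0.0,   # beta1, beta2, beta3, alpha (only beta1 used by momentum)
        1e-8, 1, 1e-3,        # eps, step, lr
        qmap1, None, absmax1, None,
        0.0, 1.0, False,      # weight_decay, gnorm_scale, skip_zeros
    )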
bitsandbytes/functional.py
@@ -82,39 +82,6 @@ str2optimizer8bit = {
     ),
 }

-str2optimizer8bit_blockwise = {
-    "adam": (
-        lib.cadam_8bit_blockwise_grad_fp32,
-        lib.cadam_8bit_blockwise_grad_fp16,
-        lib.cadam_8bit_blockwise_grad_bf16,
-    ),
-    "momentum": (
-        lib.cmomentum_8bit_blockwise_grad_fp32,
-        lib.cmomentum_8bit_blockwise_grad_fp16,
-        lib.cmomentum_8bit_blockwise_grad_bf16,
-    ),
-    "rmsprop": (
-        lib.crmsprop_8bit_blockwise_grad_fp32,
-        lib.crmsprop_8bit_blockwise_grad_fp16,
-        lib.crmsprop_8bit_blockwise_grad_bf16,
-    ),
-    "lion": (
-        lib.clion_8bit_blockwise_grad_fp32,
-        lib.clion_8bit_blockwise_grad_fp16,
-        lib.clion_8bit_blockwise_grad_bf16,
-    ),
-    "adagrad": (
-        lib.cadagrad_8bit_blockwise_grad_fp32,
-        lib.cadagrad_8bit_blockwise_grad_fp16,
-        lib.cadagrad_8bit_blockwise_grad_bf16,
-    ),
-    "ademamix": (
-        lib.cademamix_8bit_blockwise_grad_fp32,
-        lib.cademamix_8bit_blockwise_grad_fp16,
-        lib.cademamix_8bit_blockwise_grad_bf16,
-    ),
-}
-
 class GlobalPageManager:
     _instance = None
@@ -422,8 +389,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
     for t in tensors:
         # NULL pointers and paged tensors are OK.
         if t is not None and not getattr(t, "is_paged", False):
-            on_gpu &= t.is_cuda
-            gpu_ids.add(t.device.index)
+            on_gpu &= t.device.type != "cpu"
+            gpu_ids.add((t.device.type, t.device.index))

     if not on_gpu:
         raise RuntimeError(
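Editor's note: this hunk generalizes the device check from CUDA-only (t.is_cuda) to any non-CPU accelerator, and records (device_type, index) pairs so a mismatch error can name the backend, not just the index. A hedged illustration; tensor placement is assumed, not taken from the commit:

    import torch
    from bitsandbytes.functional import is_on_gpu

    if torch.cuda.is_available():
        a = torch.empty(4, device="cuda")
        b = torch.empty(4, device="cuda")
        is_on_gpu([a, b, None])    # passes: None entries are skipped, devices agree
        # is_on_gpu([a, a.cpu()])  # would raise: CPU tensors fail the generalized check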
@@ -1449,45 +1416,29 @@ def optimizer_update_8bit_blockwise(
 ) -> None:
-    optim_func = None
-    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
-    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
-    elif (
-        g.dtype == torch.bfloat16
-        and state1.dtype == torch.uint8
-        and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
-    ):
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
-    else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
-        )
-
-    is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
-
-    with _cuda_device_of(g):
-        optim_func(
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(beta3),
-            ct.c_float(alpha),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(absmax1),
-            get_ptr(absmax2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+    torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
+        optimizer_name,
+        g,
+        p,
+        state1,
+        state2,
+        beta1,
+        beta2,
+        beta3,
+        alpha,
+        eps,
+        step,
+        lr,
+        qmap1,
+        qmap2,
+        absmax1,
+        absmax2,
+        weight_decay,
+        gnorm_scale,
+        skip_zeros,
+    )


 @deprecated(
     "This function is deprecated and will be removed in a future release.",
     category=FutureWarning,
 )
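Editor's note: after this change, optimizer_update_8bit_blockwise is a thin shim over the custom op, so the ctypes dispatch lives in one place (the CUDA backend) and eager calls and torch.compile traces share the same entry point. A hedged end-to-end sketch through the public optimizer API; it assumes a CUDA build of bitsandbytes, and the layer is sized so the weight (128 x 128 = 16384 elements) exceeds the optimizer's default min_8bit_size threshold of 4096:

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(128, 128).cuda()
    opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

    loss = model(torch.randn(8, 128, device="cuda")).sum()
    loss.backward()
    opt.step()  # routes through torch.ops.bitsandbytes.optimizer_update_8bit_blockwise
    opt.zero_grad()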
bitsandbytes/optim/optimizer.py
@@ -10,6 +10,7 @@ from typing import Optional
 import torch

 import bitsandbytes.functional as F
+from bitsandbytes.utils import sync_gpu


 class MockArgs:
@@ -289,11 +290,11 @@ class Optimizer8bit(torch.optim.Optimizer):
                     self.prefetch_state(p)
                     self.update_step(group, p, gindex, pindex)
-                    torch.cuda.synchronize()
+                    sync_gpu(p)
         if self.is_paged:
             # all paged operations are asynchronous, we need
             # to sync to make sure all tensors are in the right state
-            torch.cuda.synchronize()
+            sync_gpu(loss)

         return loss
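Editor's note: the hard-coded torch.cuda.synchronize() calls are replaced with the device-agnostic sync_gpu helper (added in utils.py below), which picks the backend from the tensor's device. The synchronization itself is still required because paged optimizer state is updated by asynchronous kernels. A small guarded illustration of the pattern; the buffer size is arbitrary:

    import torch
    from bitsandbytes.utils import sync_gpu

    if torch.cuda.is_available():
        state = torch.zeros(1 << 20, device="cuda")
        state.add_(1.0)   # launched asynchronously; may not have executed yet
        sync_gpu(state)   # host now observes the completed update
        assert state[0].item() == 1.0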
bitsandbytes/utils.py
@@ -209,3 +209,10 @@ def unpack_tensor_to_dict(tensor_data):

 LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
 INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
+
+
+def sync_gpu(t: torch.Tensor):
+    if t.device.type == "cuda":
+        torch.cuda.synchronize()
+    elif t.device.type == "xpu":
+        torch.xpu.synchronize()
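Editor's note on the design choice: the fall-through is deliberate. For a CPU tensor neither branch matches and the call is a no-op, so call sites such as Optimizer8bit.step need no device checks, and additional backends can be supported by adding one elif. Illustrative call, not from the commit:

    import torch
    from bitsandbytes.utils import sync_gpu

    t = torch.ones(3)  # CPU tensor
    sync_gpu(t)        # no-op: neither the "cuda" nor the "xpu" branch matches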