OpenDAS / bitsandbytes
"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "7faa776659b1b11f130ff9c20a0b97b17f241365"
Commit b43edf56
Authored Jul 11, 2025 by Egor Krivov
Parent: adc7fda7

Add interface for 8bit optimizer
Showing 5 changed files with 224 additions and 74 deletions.
bitsandbytes/_ops.py                +61   -0
bitsandbytes/backends/cuda/ops.py   +130  -0
bitsandbytes/functional.py          +23   -72
bitsandbytes/optim/optimizer.py     +3    -2
bitsandbytes/utils.py               +7    -0
bitsandbytes/_ops.py

@@ -348,3 +348,64 @@ if ipex_cpu or ipex_xpu:
 ) -> torch.Tensor:
     torch._check_is_size(blocksize)
     return torch.empty(shape, dtype=dtype, device=A.device)
+
+
+torch.library.define(
+    "bitsandbytes::optimizer_update_8bit_blockwise",
+    "(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor qmap1, Tensor! qmap2, Tensor absmax1, Tensor! absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_8bit_blockwise")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    skip_zeros=False,
+) -> None:
+    torch._check(
+        g.numel() == p.numel(),
+        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    )
+    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+    torch._check(
+        g.dtype in compute_dtypes,
+        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    )
+    torch._check(
+        g.dtype == p.dtype,
+        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    )
+    torch._check(
+        state1.dtype == torch.uint8,
+        lambda: f"state1 must be uint8, got {state1.dtype}",
+    )
+    torch._check(
+        qmap1.dtype == absmax1.dtype == torch.float32,
+        lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    )
+    if state2 is not None:
+        torch._check(
+            state2.dtype == torch.uint8,
+            lambda: f"state2 must be uint8, got {state2.dtype}",
+        )
+        torch._check(
+            qmap2.dtype == absmax2.dtype == torch.float32,
+            lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+        )
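Note on the hunk above: it follows the standard torch.library pattern. torch.library.define declares the op's schema, and the register_fake function is a meta ("fake") kernel that only validates shapes and dtypes, which is what lets FakeTensor and torch.compile trace through the op without running device code. Below is a minimal, self-contained sketch of the same pattern; the op name mylib::scale_inplace and its CPU kernel are invented for illustration (they are not part of bitsandbytes), and it assumes PyTorch >= 2.4 for torch.library.register_fake.

import torch

# Declare the schema. "Tensor(a!) x" marks x as mutated in place (the commit
# above writes such annotations as "Tensor! ..."); "-> ()" means no return value.
torch.library.define("mylib::scale_inplace", "(Tensor(a!) x, float factor) -> ()")


# Fake (meta) kernel: metadata checks only, it never reads or writes tensor data.
@torch.library.register_fake("mylib::scale_inplace")
def _(x: torch.Tensor, factor: float) -> None:
    torch._check(x.dtype == torch.float32, lambda: f"x must be float32, got {x.dtype}")


# Real kernel for the CPU dispatch key (the commit registers its CUDA kernel
# through bitsandbytes' register_kernel helper instead).
@torch.library.impl("mylib::scale_inplace", "CPU")
def _(x: torch.Tensor, factor: float) -> None:
    x.mul_(factor)


t = torch.ones(4)
torch.ops.mylib.scale_inplace(t, 2.0)
print(t)  # tensor([2., 2., 2., 2.])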
bitsandbytes/backends/cuda/ops.py

@@ -538,3 +538,133 @@ def _gemv_4bit_impl(
             ct.c_int32(blocksize),
             stream,
         )
+
+
+str2optimizer8bit_blockwise = {
+    "adam": (
+        lib.cadam_8bit_blockwise_grad_fp32,
+        lib.cadam_8bit_blockwise_grad_fp16,
+        lib.cadam_8bit_blockwise_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum_8bit_blockwise_grad_fp32,
+        lib.cmomentum_8bit_blockwise_grad_fp16,
+        lib.cmomentum_8bit_blockwise_grad_bf16,
+    ),
+    "rmsprop": (
+        lib.crmsprop_8bit_blockwise_grad_fp32,
+        lib.crmsprop_8bit_blockwise_grad_fp16,
+        lib.crmsprop_8bit_blockwise_grad_bf16,
+    ),
+    "lion": (
+        lib.clion_8bit_blockwise_grad_fp32,
+        lib.clion_8bit_blockwise_grad_fp16,
+        lib.clion_8bit_blockwise_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad_8bit_blockwise_grad_fp32,
+        lib.cadagrad_8bit_blockwise_grad_fp16,
+        lib.cadagrad_8bit_blockwise_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix_8bit_blockwise_grad_fp32,
+        lib.cademamix_8bit_blockwise_grad_fp16,
+        lib.cademamix_8bit_blockwise_grad_bf16,
+    ),
+}
+
+
+def _optimizer_update_8bit_blockwise_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    skip_zeros=False,
+) -> None:
+    # torch._check(
+    #     g.numel() == p.numel(),
+    #     lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    # )
+    # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+    # torch._check(
+    #     g.dtype in compute_dtypes,
+    #     lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    # )
+    # torch._check(
+    #     g.dtype == p.dtype,
+    #     lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    # )
+    # torch._check(
+    #     state1.dtype == torch.uint8,
+    #     lambda: f"state1 must be uint8, got {state1.dtype}",
+    # )
+    # torch._check(
+    #     qmap1.dtype == absmax1.dtype == torch.float32,
+    #     lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    # )
+    # if state2 is not None:
+    #     torch._check(
+    #         state2.dtype == torch.uint8,
+    #         lambda: f"state2 must be uint8, got {state2.dtype}",
+    #     )
+    #     torch._check(
+    #         qmap2.dtype == absmax2.dtype == torch.float32,
+    #         lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+    #     )
+    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
+    if optimizer_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+
+    if g.dtype == torch.float32:
+        optimizer_fn = optimizer_fns[0]
+    elif g.dtype == torch.float16:
+        optimizer_fn = optimizer_fns[1]
+    elif g.dtype == torch.bfloat16:
+        optimizer_fn = optimizer_fns[2]
+    else:
+        raise ValueError(
+            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
+        )
+
+    with _cuda_device_of(g):
+        optimizer_fn(
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            get_ptr(qmap1),
+            get_ptr(qmap2),
+            get_ptr(absmax1),
+            get_ptr(absmax2),
+            ct.c_float(weight_decay),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
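A note on the call at the end of this hunk: the selected lib.c*_8bit_blockwise_grad_* symbol is a C entry point, so every tensor is handed over as a raw pointer (get_ptr) and every scalar is wrapped in a ctypes value. The sketch below shows that marshalling style in isolation, using libc's memset as a stand-in for the real bitsandbytes kernels; it assumes a POSIX-like platform where ctypes.CDLL(None) can resolve libc symbols.

import ctypes as ct

import torch

libc = ct.CDLL(None)  # assumption: POSIX-like platform; loads the process's own symbols

t = torch.ones(8, dtype=torch.uint8)   # host tensor for the demo
ptr = ct.c_void_p(t.data_ptr())        # raw pointer to the tensor's storage, like get_ptr()

# void *memset(void *s, int c, size_t n): zero the 8 bytes behind the tensor.
libc.memset(ptr, ct.c_int(0), ct.c_size_t(t.numel()))
print(t)  # tensor([0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8) -- mutated in place by the C call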
bitsandbytes/functional.py

@@ -82,39 +82,6 @@ str2optimizer8bit = {
     ),
 }
 
-str2optimizer8bit_blockwise = {
-    "adam": (
-        lib.cadam_8bit_blockwise_grad_fp32,
-        lib.cadam_8bit_blockwise_grad_fp16,
-        lib.cadam_8bit_blockwise_grad_bf16,
-    ),
-    "momentum": (
-        lib.cmomentum_8bit_blockwise_grad_fp32,
-        lib.cmomentum_8bit_blockwise_grad_fp16,
-        lib.cmomentum_8bit_blockwise_grad_bf16,
-    ),
-    "rmsprop": (
-        lib.crmsprop_8bit_blockwise_grad_fp32,
-        lib.crmsprop_8bit_blockwise_grad_fp16,
-        lib.crmsprop_8bit_blockwise_grad_bf16,
-    ),
-    "lion": (
-        lib.clion_8bit_blockwise_grad_fp32,
-        lib.clion_8bit_blockwise_grad_fp16,
-        lib.clion_8bit_blockwise_grad_bf16,
-    ),
-    "adagrad": (
-        lib.cadagrad_8bit_blockwise_grad_fp32,
-        lib.cadagrad_8bit_blockwise_grad_fp16,
-        lib.cadagrad_8bit_blockwise_grad_bf16,
-    ),
-    "ademamix": (
-        lib.cademamix_8bit_blockwise_grad_fp32,
-        lib.cademamix_8bit_blockwise_grad_fp16,
-        lib.cademamix_8bit_blockwise_grad_bf16,
-    ),
-}
 
 
 class GlobalPageManager:
     _instance = None
@@ -422,8 +389,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
     for t in tensors:
         # NULL pointers and paged tensors are OK.
         if t is not None and not getattr(t, "is_paged", False):
-            on_gpu &= t.is_cuda
-            gpu_ids.add(t.device.index)
+            on_gpu &= t.device.type != "cpu"
+            gpu_ids.add((t.device.type, t.device.index))
 
     if not on_gpu:
         raise RuntimeError(
@@ -1449,44 +1416,28 @@ def optimizer_update_8bit_blockwise(
 ) -> None:
     optim_func = None
-    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
-    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
-    elif (
-        g.dtype == torch.bfloat16
-        and state1.dtype == torch.uint8
-        and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
-    ):
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
-    else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
-        )
 
     is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
 
-    with _cuda_device_of(g):
-        optim_func(
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(beta3),
-            ct.c_float(alpha),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(absmax1),
-            get_ptr(absmax2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+    torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
+        optimizer_name,
+        g,
+        p,
+        state1,
+        state2,
+        beta1,
+        beta2,
+        beta3,
+        alpha,
+        eps,
+        step,
+        lr,
+        qmap1,
+        qmap2,
+        absmax1,
+        absmax2,
+        weight_decay,
+        gnorm_scale,
+        skip_zeros,
+    )
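Two things change in functional.py: the dtype dispatch and the ctypes call move behind torch.ops.bitsandbytes.optimizer_update_8bit_blockwise, and is_on_gpu no longer requires t.is_cuda, only a non-CPU device, tracking device identity as a (type, index) pair so XPU tensors also qualify. Below is a standalone restatement of that relaxed check; the helper name check_same_non_cpu_device and its error messages are mine, not the library's.

from typing import Iterable, Optional

import torch


def check_same_non_cpu_device(tensors: Iterable[Optional[torch.Tensor]]) -> None:
    on_gpu = True
    device_ids = set()
    for t in tensors:
        # None entries and paged tensors are skipped, mirroring is_on_gpu above.
        if t is not None and not getattr(t, "is_paged", False):
            on_gpu &= t.device.type != "cpu"           # any accelerator passes, not just CUDA
            device_ids.add((t.device.type, t.device.index))
    if not on_gpu:
        raise RuntimeError("All tensors must be on a non-CPU device.")
    if len(device_ids) > 1:
        raise RuntimeError(f"Tensors are spread across devices: {device_ids}")


check_same_non_cpu_device([None, None])  # passes trivially: nothing to check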
bitsandbytes/optim/optimizer.py

@@ -10,6 +10,7 @@ from typing import Optional
 import torch
 
 import bitsandbytes.functional as F
+from bitsandbytes.utils import sync_gpu
 
 
 class MockArgs:
@@ -289,11 +290,11 @@ class Optimizer8bit(torch.optim.Optimizer):
                 self.prefetch_state(p)
                 self.update_step(group, p, gindex, pindex)
-                torch.cuda.synchronize()
+                sync_gpu(p)
         if self.is_paged:
             # all paged operations are asynchronous, we need
             # to sync to make sure all tensors are in the right state
-            torch.cuda.synchronize()
+            sync_gpu(loss)
 
         return loss
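These two lines are the user-visible effect of the commit inside Optimizer8bit.step(): the hard-coded torch.cuda.synchronize() becomes a device-aware sync_gpu(...) call, so the same step() path works for CUDA and XPU builds. A hedged usage sketch follows (bitsandbytes.optim.Adam8bit is the standard entry point, but the shapes and hyperparameters are illustrative and a CUDA device is assumed):

import torch

import bitsandbytes as bnb

model = torch.nn.Linear(128, 128).cuda()
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)  # 8-bit blockwise optimizer state

x = torch.randn(16, 128, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
opt.step()       # the step path above now ends with sync_gpu(p) / sync_gpu(loss)
opt.zero_grad()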
bitsandbytes/utils.py

@@ -209,3 +209,10 @@ def unpack_tensor_to_dict(tensor_data):
 LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
 INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
+
+
+def sync_gpu(t: torch.Tensor):
+    if t.device.type == "cuda":
+        torch.cuda.synchronize()
+    elif t.device.type == "xpu":
+        torch.xpu.synchronize()
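sync_gpu simply issues a blocking synchronize on the device family that owns the tensor (CUDA or XPU). One place such a helper is handy outside the optimizer is timing asynchronous kernels, since the host-side clock otherwise stops before the device work finishes; the timing function below is an illustrative sketch, not part of the commit.

import time

import torch
from bitsandbytes.utils import sync_gpu  # the helper added above


def timed_matmul(device: str = "cuda") -> float:
    a = torch.randn(2048, 2048, device=device)
    b = torch.randn(2048, 2048, device=device)
    sync_gpu(a)                       # drain pending work before starting the clock
    start = time.perf_counter()
    c = a @ b
    sync_gpu(c)                       # wait for the matmul to actually finish
    return time.perf_counter() - start


# print(f"matmul: {timed_matmul():.4f}s")  # requires a CUDA or XPU device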