OpenDAS / bitsandbytes · Commits

Commit 941681da (unverified), authored Jul 14, 2025 by Matthew Douglas; committed by GitHub, Jul 14, 2025.

Merge pull request #1706 from Egor-Krivov/egor/8bit_int

Add kernel registration for 8bit and 32bit optimizers

Parents: adc7fda7, 0f6fe6bf
Showing 7 changed files with 417 additions and 172 deletions (+417 −172):

- bitsandbytes/_ops.py (+104 −0)
- bitsandbytes/backends/cuda/ops.py (+226 −0)
- bitsandbytes/functional.py (+43 −141)
- bitsandbytes/optim/optimizer.py (+5 −3)
- bitsandbytes/utils.py (+7 −0)
- tests/helpers.py (+3 −3)
- tests/test_optim.py (+29 −25)
bitsandbytes/_ops.py

@@ -348,3 +348,107 @@ if ipex_cpu or ipex_xpu:
 ) -> torch.Tensor:
     torch._check_is_size(blocksize)
     return torch.empty(shape, dtype=dtype, device=A.device)
+
+
+torch.library.define(
+    "bitsandbytes::optimizer_update_32bit",
+    "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_32bit")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    torch._check(
+        g.numel() == p.numel(),
+        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    )
+    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    torch._check(
+        g.dtype in compute_dtypes,
+        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    )
+    torch._check(
+        g.dtype == p.dtype,
+        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    )
+
+
+torch.library.define(
+    "bitsandbytes::optimizer_update_8bit_blockwise",
+    "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_8bit_blockwise")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    torch._check(
+        g.numel() == p.numel(),
+        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    )
+    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    torch._check(
+        g.dtype in compute_dtypes,
+        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    )
+    torch._check(
+        g.dtype == p.dtype,
+        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    )
+    torch._check(
+        state1.dtype == torch.uint8,
+        lambda: f"state1 must be uint8, got {state1.dtype}",
+    )
+    torch._check(
+        qmap1.dtype == absmax1.dtype == torch.float32,
+        lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    )
+    if state2 is not None:
+        torch._check(
+            state2.dtype == torch.uint8,
+            lambda: f"state2 must be uint8, got {state2.dtype}",
+        )
+        torch._check(
+            qmap2.dtype == absmax2.dtype == torch.float32,
+            lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+        )
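The pattern added here is PyTorch's custom-op registration: torch.library.define declares the operator schema, and register_fake supplies a shape- and dtype-only implementation so the op can be traced (FakeTensor/meta propagation, torch.compile) without touching a real backend. Below is a minimal, self-contained sketch of that pattern using a made-up demo::scale_add op rather than the bitsandbytes operators above; it is illustrative only.

```python
import torch

# Hypothetical op for illustration; the commit's real ops are
# bitsandbytes::optimizer_update_32bit and ::optimizer_update_8bit_blockwise.
torch.library.define("demo::scale_add", "(Tensor x, Tensor y, float alpha) -> Tensor")


@torch.library.register_fake("demo::scale_add")
def _(x, y, alpha):
    # Fake (meta) kernel: validate inputs and propagate shape/dtype, no real work.
    torch._check(x.shape == y.shape, lambda: "x and y must have the same shape")
    return torch.empty_like(x)


@torch.library.impl("demo::scale_add", "CPU")
def _(x, y, alpha):
    # Real kernel for one backend; the commit registers CUDA kernels the same
    # way through bitsandbytes' register_kernel() helper.
    return x + alpha * y


print(torch.ops.demo.scale_add(torch.ones(3), torch.ones(3), 2.0))  # tensor([3., 3., 3.])
```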
bitsandbytes/backends/cuda/ops.py

@@ -538,3 +538,229 @@ def _gemv_4bit_impl(
             ct.c_int32(blocksize),
             stream,
         )
+
+
+"""C FUNCTIONS FOR OPTIMIZERS"""
+str2optimizer32bit = {
+    "adam": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum32bit_grad_32,
+        lib.cmomentum32bit_grad_16,
+    ),
+    "rmsprop": (
+        lib.crmsprop32bit_grad_32,
+        lib.crmsprop32bit_grad_16,
+    ),
+    "lion": (
+        lib.clion32bit_grad_fp32,
+        lib.clion32bit_grad_fp16,
+        lib.clion32bit_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad32bit_grad_32,
+        lib.cadagrad32bit_grad_16,
+    ),
+    "lamb": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix32bit_grad_fp32,
+        lib.cademamix32bit_grad_fp16,
+        lib.cademamix32bit_grad_bf16,
+    ),
+}
+
+str2optimizer8bit_blockwise = {
+    "adam": (
+        lib.cadam_8bit_blockwise_grad_fp32,
+        lib.cadam_8bit_blockwise_grad_fp16,
+        lib.cadam_8bit_blockwise_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum_8bit_blockwise_grad_fp32,
+        lib.cmomentum_8bit_blockwise_grad_fp16,
+        lib.cmomentum_8bit_blockwise_grad_bf16,
+    ),
+    "rmsprop": (
+        lib.crmsprop_8bit_blockwise_grad_fp32,
+        lib.crmsprop_8bit_blockwise_grad_fp16,
+        lib.crmsprop_8bit_blockwise_grad_bf16,
+    ),
+    "lion": (
+        lib.clion_8bit_blockwise_grad_fp32,
+        lib.clion_8bit_blockwise_grad_fp16,
+        lib.clion_8bit_blockwise_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad_8bit_blockwise_grad_fp32,
+        lib.cadagrad_8bit_blockwise_grad_fp16,
+        lib.cadagrad_8bit_blockwise_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix_8bit_blockwise_grad_fp32,
+        lib.cademamix_8bit_blockwise_grad_fp16,
+        lib.cademamix_8bit_blockwise_grad_bf16,
+    ),
+}
+
+
+def _optimizer_update_32bit_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    optim_fns = str2optimizer32bit.get(optimizer_name, None)
+    if optim_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+    if g.dtype == torch.float32:
+        optim_func = optim_fns[0]
+    elif g.dtype == torch.float16:
+        optim_func = optim_fns[1]
+    elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
+        optim_func = optim_fns[2]
+    else:
+        raise ValueError(
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+        )
+
+    with _cuda_device_of(g):
+        optim_func(
+            get_ptr(g),
+            get_ptr(p),
+            get_ptr(state1),
+            get_ptr(state2),
+            get_ptr(unorm_vec),
+            ct.c_float(max_unorm),
+            ct.c_float(param_norm),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_float(weight_decay),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+def _optimizer_update_8bit_blockwise_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
+    # torch._check(
+    #     g.numel() == p.numel(),
+    #     lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    # )
+    # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+    # torch._check(
+    #     g.dtype in compute_dtypes,
+    #     lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    # )
+    # torch._check(
+    #     g.dtype == p.dtype,
+    #     lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    # )
+    # torch._check(
+    #     state1.dtype == torch.uint8,
+    #     lambda: f"state1 must be uint8, got {state1.dtype}",
+    # )
+    # torch._check(
+    #     qmap1.dtype == absmax1.dtype == torch.float32,
+    #     lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    # )
+    # if state2 is not None:
+    #     torch._check(
+    #         state2.dtype == torch.uint8,
+    #         lambda: f"state2 must be uint8, got {state2.dtype}",
+    #     )
+    #     torch._check(
+    #         qmap2.dtype == absmax2.dtype == torch.float32,
+    #         lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+    #     )
+    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
+    if optimizer_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+
+    if g.dtype == torch.float32:
+        optimizer_fn = optimizer_fns[0]
+    elif g.dtype == torch.float16:
+        optimizer_fn = optimizer_fns[1]
+    elif g.dtype == torch.bfloat16:
+        optimizer_fn = optimizer_fns[2]
+    else:
+        raise ValueError(
+            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
+        )
+
+    with _cuda_device_of(g):
+        optimizer_fn(
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            get_ptr(qmap1),
+            get_ptr(qmap2),
+            get_ptr(absmax1),
+            get_ptr(absmax2),
+            ct.c_float(weight_decay),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
+register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)
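Once these CUDA implementations are bound to the schemas from _ops.py via register_kernel(...), callers reach them through the dispatcher as torch.ops.bitsandbytes.* instead of through ctypes. A hedged sketch of a direct call follows; the argument values are illustrative, and it assumes a CUDA build of bitsandbytes in which importing the package registers these kernels.

```python
import torch
import bitsandbytes  # noqa: F401  -- importing registers the custom ops and CUDA kernels

g = torch.randn(4096, device="cuda")   # gradient
p = torch.randn(4096, device="cuda")   # parameter
state1 = torch.zeros_like(p)           # Adam first moment
state2 = torch.zeros_like(p)           # Adam second moment

# Argument order follows the schema declared in bitsandbytes/_ops.py.
torch.ops.bitsandbytes.optimizer_update_32bit(
    "adam", g, p, state1, state2,
    None,        # unorm_vec
    0.0, 0.0,    # max_unorm, param_norm
    0.9, 0.999,  # beta1, beta2
    0.0, 0.0,    # beta3, alpha (unused by plain Adam)
    1e-8,        # eps
    0.0,         # weight_decay
    1,           # step
    1e-3,        # lr
    1.0,         # gnorm_scale
    False,       # skip_zeros
)
```

In practice nothing calls the op this way; bitsandbytes.functional (next file) wraps it, and the optimizer classes call the functional API.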
bitsandbytes/functional.py

@@ -20,41 +20,6 @@ from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib
 name2qmap = {}
 
 """C FUNCTIONS FOR OPTIMIZERS"""
-str2optimizer32bit = {
-    "adam": (
-        lib.cadam32bit_grad_fp32,
-        lib.cadam32bit_grad_fp16,
-        lib.cadam32bit_grad_bf16,
-    ),
-    "momentum": (
-        lib.cmomentum32bit_grad_32,
-        lib.cmomentum32bit_grad_16,
-    ),
-    "rmsprop": (
-        lib.crmsprop32bit_grad_32,
-        lib.crmsprop32bit_grad_16,
-    ),
-    "lion": (
-        lib.clion32bit_grad_fp32,
-        lib.clion32bit_grad_fp16,
-        lib.clion32bit_grad_bf16,
-    ),
-    "adagrad": (
-        lib.cadagrad32bit_grad_32,
-        lib.cadagrad32bit_grad_16,
-    ),
-    "lamb": (
-        lib.cadam32bit_grad_fp32,
-        lib.cadam32bit_grad_fp16,
-        lib.cadam32bit_grad_bf16,
-    ),
-    "ademamix": (
-        lib.cademamix32bit_grad_fp32,
-        lib.cademamix32bit_grad_fp16,
-        lib.cademamix32bit_grad_bf16,
-    ),
-}
-
 str2optimizer8bit = {
     "adam": (
         lib.cadam_static_8bit_grad_32,
@@ -82,39 +47,6 @@ str2optimizer8bit = {
     ),
 }
 
-str2optimizer8bit_blockwise = {
-    "adam": (
-        lib.cadam_8bit_blockwise_grad_fp32,
-        lib.cadam_8bit_blockwise_grad_fp16,
-        lib.cadam_8bit_blockwise_grad_bf16,
-    ),
-    "momentum": (
-        lib.cmomentum_8bit_blockwise_grad_fp32,
-        lib.cmomentum_8bit_blockwise_grad_fp16,
-        lib.cmomentum_8bit_blockwise_grad_bf16,
-    ),
-    "rmsprop": (
-        lib.crmsprop_8bit_blockwise_grad_fp32,
-        lib.crmsprop_8bit_blockwise_grad_fp16,
-        lib.crmsprop_8bit_blockwise_grad_bf16,
-    ),
-    "lion": (
-        lib.clion_8bit_blockwise_grad_fp32,
-        lib.clion_8bit_blockwise_grad_fp16,
-        lib.clion_8bit_blockwise_grad_bf16,
-    ),
-    "adagrad": (
-        lib.cadagrad_8bit_blockwise_grad_fp32,
-        lib.cadagrad_8bit_blockwise_grad_fp16,
-        lib.cadagrad_8bit_blockwise_grad_bf16,
-    ),
-    "ademamix": (
-        lib.cademamix_8bit_blockwise_grad_fp32,
-        lib.cademamix_8bit_blockwise_grad_fp16,
-        lib.cademamix_8bit_blockwise_grad_bf16,
-    ),
-}
-
 class GlobalPageManager:
     _instance = None
@@ -422,8 +354,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
     for t in tensors:
         # NULL pointers and paged tensors are OK.
         if t is not None and not getattr(t, "is_paged", False):
-            on_gpu &= t.is_cuda
-            gpu_ids.add(t.device.index)
+            on_gpu &= t.device.type != "cpu"
+            gpu_ids.add((t.device.type, t.device.index))
 
     if not on_gpu:
         raise RuntimeError(
@@ -1252,40 +1184,26 @@ def optimizer_update_32bit(
     if max_unorm > 0.0:
         param_norm = torch.norm(p.data.float())
 
-    optim_func = None
-    if g.dtype == torch.float32:
-        optim_func = str2optimizer32bit[optimizer_name][0]
-    elif g.dtype == torch.float16:
-        optim_func = str2optimizer32bit[optimizer_name][1]
-    elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3:
-        optim_func = str2optimizer32bit[optimizer_name][2]
-    else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
-        )
-
     is_on_gpu([g, p, state1, state2, unorm_vec])
 
-    with _cuda_device_of(g):
-        optim_func(
-            get_ptr(g),
-            get_ptr(p),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(beta3),
-            ct.c_float(alpha),
-            ct.c_float(eps),
-            ct.c_float(weight_decay),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+    torch.ops.bitsandbytes.optimizer_update_32bit(
+        optimizer_name,
+        g,
+        p,
+        state1,
+        state2,
+        unorm_vec,
+        max_unorm,
+        param_norm,
+        beta1,
+        beta2,
+        beta3,
+        alpha,
+        eps,
+        weight_decay,
+        step,
+        lr,
+        gnorm_scale,
+        skip_zeros,
+    )
@@ -1449,44 +1367,28 @@ def optimizer_update_8bit_blockwise(
 ) -> None:
     optim_func = None
 
-    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
-    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
-    elif (
-        g.dtype == torch.bfloat16
-        and state1.dtype == torch.uint8
-        and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
-    ):
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
-    else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
-        )
-
     is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
 
-    with _cuda_device_of(g):
-        optim_func(
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(beta3),
-            ct.c_float(alpha),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(absmax1),
-            get_ptr(absmax2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+    torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
+        optimizer_name,
+        g,
+        p,
+        state1,
+        state2,
+        beta1,
+        beta2,
+        beta3,
+        alpha,
+        eps,
+        step,
+        lr,
+        qmap1,
+        qmap2,
+        absmax1,
+        absmax2,
+        weight_decay,
+        gnorm_scale,
+        skip_zeros,
+    )
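The upshot of the functional.py changes is that optimizer_update_32bit and optimizer_update_8bit_blockwise keep their public signatures but now route through torch.ops, so user-facing optimizer code is untouched. A minimal usage sketch, assuming a CUDA-enabled bitsandbytes install with this change:

```python
import torch
import bitsandbytes as bnb

# The 8-bit blockwise Adam path now reaches the CUDA kernel via
# torch.ops.bitsandbytes.optimizer_update_8bit_blockwise instead of a direct
# ctypes call, but the optimizer API is identical.
layer = torch.nn.Linear(256, 256).cuda()
opt = bnb.optim.Adam8bit(layer.parameters(), lr=1e-3)

loss = layer(torch.randn(16, 256, device="cuda")).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```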
bitsandbytes/optim/optimizer.py

@@ -10,6 +10,7 @@ from typing import Optional
 import torch
 
 import bitsandbytes.functional as F
+from bitsandbytes.utils import sync_gpu
 
 
 class MockArgs:
@@ -279,6 +280,7 @@ class Optimizer8bit(torch.optim.Optimizer):
             self.initialized = True
 
         # if self.is_paged: self.page_mng.prefetch_all()
+        p = None
        for gindex, group in enumerate(self.param_groups):
             for pindex, p in enumerate(group["params"]):
                 if p.grad is None:
@@ -289,11 +291,11 @@ class Optimizer8bit(torch.optim.Optimizer):
 
                 self.prefetch_state(p)
                 self.update_step(group, p, gindex, pindex)
-                torch.cuda.synchronize()
-        if self.is_paged:
+                sync_gpu(p)
+        if self.is_paged and p is not None:
             # all paged operations are asynchronous, we need
             # to sync to make sure all tensors are in the right state
-            torch.cuda.synchronize()
+            sync_gpu(p)
 
         return loss
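Because the post-loop synchronization now needs a tensor (sync_gpu(p)) to pick the device, step() pre-initializes p = None and only syncs the paged state when at least one parameter was actually visited. A small control-flow sketch of that guard; step_sketch and sync are made-up names, not part of the commit:

```python
import torch


def step_sketch(param_groups, is_paged, sync):
    # Mirrors the control flow added in Optimizer8bit.step(): `p` must exist
    # even when every group is empty, and the final paged sync needs a real
    # tensor to infer the device from, hence the `p is not None` guard.
    p = None
    for group in param_groups:
        for p in group["params"]:
            sync(p)                      # per-parameter sync, as in the update loop
    if is_paged and p is not None:
        sync(p)                          # final sync for asynchronous paged ops


# With no parameters at all, nothing is synchronized and nothing raises:
step_sketch([{"params": []}], is_paged=True, sync=lambda t: None)
```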
bitsandbytes/utils.py

@@ -209,3 +209,10 @@ def unpack_tensor_to_dict(tensor_data):
 LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
 INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
+
+
+def sync_gpu(t: torch.Tensor):
+    if t.device.type == "cuda":
+        torch.cuda.synchronize()
+    elif t.device.type == "xpu":
+        torch.xpu.synchronize()
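sync_gpu is a small device-dispatching wrapper: it calls torch.cuda.synchronize() or torch.xpu.synchronize() depending on where the tensor lives and does nothing otherwise. A sketch of using it in timing code, mirroring how the benchmark test uses it; the snippet itself is illustrative:

```python
import time

import torch
from bitsandbytes.utils import sync_gpu

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(2048, 2048, device=device)

sync_gpu(a)                  # drain pending work before starting the clock (no-op on CPU)
t0 = time.time()
b = a @ a
sync_gpu(b)                  # make sure the matmul actually finished
print(f"matmul: {time.time() - t0:.4f}s on {device}")
```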
tests/helpers.py

@@ -18,12 +18,12 @@ BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (boo
 @functools.cache
-def get_available_devices():
+def get_available_devices(no_cpu=False):
     if "BNB_TEST_DEVICE" in os.environ:
         # If the environment variable is set, use it directly.
-        return [os.environ["BNB_TEST_DEVICE"]]
+        return [d for d in os.environ["BNB_TEST_DEVICE"] if d.lower() != "cpu"]
 
-    devices = [] if HIP_ENVIRONMENT else ["cpu"]
+    devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else []
 
     if hasattr(torch, "accelerator"):
         # PyTorch 2.6+ - determine accelerator using agnostic API.
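get_available_devices(no_cpu=True) is what lets the optimizer tests replace the requires_cuda fixture with a device parametrization over whichever accelerators the machine (or BNB_TEST_DEVICE) exposes, excluding CPU. A sketch of a device-parametrized test in the same style; test_smoke is a hypothetical test, not part of this commit:

```python
import pytest
import torch

from tests.helpers import get_available_devices, id_formatter


@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
def test_smoke(device):
    # Runs once per non-CPU device; collected as zero tests when none is available.
    x = torch.ones(8, device=device)
    assert x.sum().item() == 8.0
```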
tests/test_optim.py

@@ -11,7 +11,8 @@ import torch
 import bitsandbytes as bnb
 import bitsandbytes.functional as F
-from tests.helpers import describe_dtype, id_formatter
+from bitsandbytes.utils import sync_gpu
+from tests.helpers import describe_dtype, get_available_devices, id_formatter
 
 # import apex
@@ -168,7 +169,8 @@ optimizer_names_32bit = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
-def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
+def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
     if optim_name.startswith("paged_") and sys.platform == "win32":
         pytest.skip("Paged optimizers can have issues on Windows.")
@@ -176,7 +178,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
         pytest.skip()
 
     if dim1 == 1 and dim2 == 1:
         return
-    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
+    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
     p2 = p1.clone()
     p1 = p1.float()
@@ -191,7 +193,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
     atol, rtol = 1e-4, 1e-3
 
     for i in range(k):
-        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
+        g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
         p1.grad = g.clone().float()
         p2.grad = g.clone()
@@ -201,14 +203,14 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
         for name1, name2 in str2statenames[optim_name]:
             torch.testing.assert_close(
                 torch_optimizer.state[p1][name1],
-                bnb_optimizer.state[p2][name2].cuda(),
+                bnb_optimizer.state[p2][name2].to(device),
                 atol=atol,
                 rtol=rtol,
             )
 
         # since Lion can have pretty noisy updates where things lie at the boundary
-        # allow up to 10 errors for Lion
-        assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)
+        # allow up to 15 errors for Lion
+        assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=15)
 
         if i % (k // 5) == 0 and i > 0:
             path = get_temp_dir()
@@ -247,7 +249,8 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
-def test_global_config(requires_cuda, dim1, dim2, gtype):
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+def test_global_config(dim1, dim2, gtype, device):
     if dim1 == 1 and dim2 == 1:
         return
 
     p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
@@ -263,9 +266,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
     bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
 
     bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
-    p1 = p1.cuda()
-    p2 = p2.cuda()
-    p3 = p3.cuda()
+    p1 = p1.to(device)
+    p2 = p2.to(device)
+    p3 = p3.to(device)
 
     adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)
@@ -275,9 +278,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
     atol, rtol = 1e-4, 1e-3
 
     for i in range(50):
-        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
-        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
-        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
+        g1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
+        g2 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
+        g3 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
         p1.grad = g1
         p2.grad = g2
         p3.grad = g3
@@ -302,13 +305,14 @@ optimizer_names_8bit = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
-def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
     torch.set_printoptions(precision=6)
 
     if dim1 == 1 and dim2 == 1:
         return
-    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
+    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
     p2 = p1.clone()
     p1 = p1.float()
     blocksize = 256
@@ -330,15 +334,15 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
     relerrors = []
 
     for i in range(50):
-        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
+        g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
         p1.grad = g.clone().float()
         p2.grad = g.clone()
 
-        bnb_optimizer.step()
-        torch_optimizer.step()
+        torch_optimizer.step()
+        bnb_optimizer.step()
 
         # since Lion can have pretty noisy updates where things lie at the boundary
-        assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
+        # assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
 
         dequant_states = []
         for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -368,7 +372,7 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
             )
             num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
-            # assert num_not_close.sum().item() < 20
+            assert num_not_close.sum().item() < 20
             dequant_states.append(s1.clone())
 
         err = torch.abs(p1 - p2)
@@ -549,25 +553,25 @@ optimizer_names_benchmark = [
 @pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype)
 @pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
 @pytest.mark.benchmark
-def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
+def test_benchmark_blockwise(dim1, dim2, gtype, optim_name, device):
     if dim1 == 1 and dim2 == 1:
         return
-    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
+    p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
 
     bnb_optimizer = str2optimizers[optim_name][1]([p1])
 
-    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
+    g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
     p1.grad = g
 
     total_steps = 500
     for i in range(total_steps):
         if i == total_steps // 5:
             # 100 iterations for burn-in
-            torch.cuda.synchronize()
+            sync_gpu(p1)
             t0 = time.time()
 
         bnb_optimizer.step()
 
-    torch.cuda.synchronize()
+    sync_gpu(p1)
     s = time.time() - t0
     print("")
     params = (total_steps - total_steps // 5) * dim1 * dim2
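Several assertions in these tests go through assert_most_approx_close, which tolerates a bounded number of elementwise mismatches (the Lion allowance is raised from 10 to 15 here). The helper's exact definition is not shown in this diff; the sketch below reproduces the idea under that assumption:

```python
import torch


def assert_most_approx_close(a, b, atol=1e-8, rtol=1e-5, max_error_count=0):
    # Allow up to `max_error_count` elementwise mismatches before failing,
    # e.g. Lion updates that land exactly on a quantization boundary.
    mismatches = (~torch.isclose(a, b, atol=atol, rtol=rtol)).sum().item()
    if mismatches > max_error_count:
        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)


x = torch.randn(1024)
y = x.clone()
y[:10] += 1.0  # 10 clear outliers
assert_most_approx_close(x, y, atol=1e-4, rtol=1e-3, max_error_count=15)  # passes
```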