Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
941681da
Unverified
Commit
941681da
authored
Jul 14, 2025
by
Matthew Douglas
Committed by
GitHub
Jul 14, 2025
Browse files
Merge pull request #1706 from Egor-Krivov/egor/8bit_int
Add kernel registration for 8bit and 32bit optimizers
parents
adc7fda7
0f6fe6bf
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
417 additions
and
172 deletions
+417
-172
bitsandbytes/_ops.py
bitsandbytes/_ops.py
+104
-0
bitsandbytes/backends/cuda/ops.py
bitsandbytes/backends/cuda/ops.py
+226
-0
bitsandbytes/functional.py
bitsandbytes/functional.py
+43
-141
bitsandbytes/optim/optimizer.py
bitsandbytes/optim/optimizer.py
+5
-3
bitsandbytes/utils.py
bitsandbytes/utils.py
+7
-0
tests/helpers.py
tests/helpers.py
+3
-3
tests/test_optim.py
tests/test_optim.py
+29
-25
No files found.
bitsandbytes/_ops.py
View file @
941681da
...
...
@@ -348,3 +348,107 @@ if ipex_cpu or ipex_xpu:
)
->
torch
.
Tensor
:
torch
.
_check_is_size
(
blocksize
)
return
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
A
.
device
)
torch
.
library
.
define
(
"bitsandbytes::optimizer_update_32bit"
,
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()"
,
)
@
register_fake
(
"bitsandbytes::optimizer_update_32bit"
)
def
_
(
optimizer_name
:
str
,
g
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
state1
:
torch
.
Tensor
,
state2
:
Optional
[
torch
.
Tensor
],
unorm_vec
:
Optional
[
torch
.
Tensor
],
max_unorm
:
float
,
param_norm
:
float
,
beta1
:
float
,
beta2
:
float
,
beta3
:
float
,
alpha
:
float
,
eps
:
float
,
weight_decay
:
float
,
step
:
int
,
lr
:
float
,
gnorm_scale
:
float
,
skip_zeros
=
False
,
)
->
None
:
torch
.
_check
(
g
.
numel
()
==
p
.
numel
(),
lambda
:
f
"g and p must have the same number of elements, got
{
g
.
numel
()
}
and
{
p
.
numel
()
}
"
,
)
compute_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
torch
.
_check
(
g
.
dtype
in
compute_dtypes
,
lambda
:
f
"g must be bfloat16, float16, or float32, got
{
g
.
dtype
}
"
,
)
torch
.
_check
(
g
.
dtype
==
p
.
dtype
,
lambda
:
f
"Expected all tensors to have the same dtype, got g.dtype=
{
g
.
dtype
}
, p.dtype=
{
p
.
dtype
}
"
,
)
torch
.
library
.
define
(
"bitsandbytes::optimizer_update_8bit_blockwise"
,
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()"
,
)
@
register_fake
(
"bitsandbytes::optimizer_update_8bit_blockwise"
)
def
_
(
optimizer_name
:
str
,
g
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
state1
:
torch
.
Tensor
,
state2
:
Optional
[
torch
.
Tensor
],
beta1
:
float
,
beta2
:
float
,
beta3
:
float
,
alpha
:
float
,
eps
:
float
,
step
:
int
,
lr
:
float
,
qmap1
:
torch
.
Tensor
,
qmap2
:
Optional
[
torch
.
Tensor
],
absmax1
:
torch
.
Tensor
,
absmax2
:
Optional
[
torch
.
Tensor
],
weight_decay
:
float
,
gnorm_scale
:
float
,
skip_zeros
=
False
,
)
->
None
:
torch
.
_check
(
g
.
numel
()
==
p
.
numel
(),
lambda
:
f
"g and p must have the same number of elements, got
{
g
.
numel
()
}
and
{
p
.
numel
()
}
"
,
)
compute_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
torch
.
_check
(
g
.
dtype
in
compute_dtypes
,
lambda
:
f
"g must be bfloat16, float16, or float32, got
{
g
.
dtype
}
"
,
)
torch
.
_check
(
g
.
dtype
==
p
.
dtype
,
lambda
:
f
"Expected all tensors to have the same dtype, got g.dtype=
{
g
.
dtype
}
, p.dtype=
{
p
.
dtype
}
"
,
)
torch
.
_check
(
state1
.
dtype
==
torch
.
uint8
,
lambda
:
f
"state1 must be uint8, got
{
state1
.
dtype
}
"
,
)
torch
.
_check
(
qmap1
.
dtype
==
absmax1
.
dtype
==
torch
.
float32
,
lambda
:
f
"Expected qmap1 and absmax1 to be float32, got qmap1.dtype=
{
qmap1
.
dtype
}
, absmax1.dtype=
{
absmax1
.
dtype
}
"
,
)
if
state2
is
not
None
:
torch
.
_check
(
state2
.
dtype
==
torch
.
uint8
,
lambda
:
f
"state2 must be uint8, got
{
state2
.
dtype
}
"
,
)
torch
.
_check
(
qmap2
.
dtype
==
absmax2
.
dtype
==
torch
.
float32
,
lambda
:
f
"Expected qmap2 and absmax2 to be float32, got qmap2.dtype=
{
qmap2
.
dtype
}
, absmax2.dtype=
{
absmax2
.
dtype
}
"
,
)
bitsandbytes/backends/cuda/ops.py
View file @
941681da
...
...
@@ -538,3 +538,229 @@ def _gemv_4bit_impl(
ct
.
c_int32
(
blocksize
),
stream
,
)
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit
=
{
"adam"
:
(
lib
.
cadam32bit_grad_fp32
,
lib
.
cadam32bit_grad_fp16
,
lib
.
cadam32bit_grad_bf16
,
),
"momentum"
:
(
lib
.
cmomentum32bit_grad_32
,
lib
.
cmomentum32bit_grad_16
,
),
"rmsprop"
:
(
lib
.
crmsprop32bit_grad_32
,
lib
.
crmsprop32bit_grad_16
,
),
"lion"
:
(
lib
.
clion32bit_grad_fp32
,
lib
.
clion32bit_grad_fp16
,
lib
.
clion32bit_grad_bf16
,
),
"adagrad"
:
(
lib
.
cadagrad32bit_grad_32
,
lib
.
cadagrad32bit_grad_16
,
),
"lamb"
:
(
lib
.
cadam32bit_grad_fp32
,
lib
.
cadam32bit_grad_fp16
,
lib
.
cadam32bit_grad_bf16
,
),
"ademamix"
:
(
lib
.
cademamix32bit_grad_fp32
,
lib
.
cademamix32bit_grad_fp16
,
lib
.
cademamix32bit_grad_bf16
,
),
}
str2optimizer8bit_blockwise
=
{
"adam"
:
(
lib
.
cadam_8bit_blockwise_grad_fp32
,
lib
.
cadam_8bit_blockwise_grad_fp16
,
lib
.
cadam_8bit_blockwise_grad_bf16
,
),
"momentum"
:
(
lib
.
cmomentum_8bit_blockwise_grad_fp32
,
lib
.
cmomentum_8bit_blockwise_grad_fp16
,
lib
.
cmomentum_8bit_blockwise_grad_bf16
,
),
"rmsprop"
:
(
lib
.
crmsprop_8bit_blockwise_grad_fp32
,
lib
.
crmsprop_8bit_blockwise_grad_fp16
,
lib
.
crmsprop_8bit_blockwise_grad_bf16
,
),
"lion"
:
(
lib
.
clion_8bit_blockwise_grad_fp32
,
lib
.
clion_8bit_blockwise_grad_fp16
,
lib
.
clion_8bit_blockwise_grad_bf16
,
),
"adagrad"
:
(
lib
.
cadagrad_8bit_blockwise_grad_fp32
,
lib
.
cadagrad_8bit_blockwise_grad_fp16
,
lib
.
cadagrad_8bit_blockwise_grad_bf16
,
),
"ademamix"
:
(
lib
.
cademamix_8bit_blockwise_grad_fp32
,
lib
.
cademamix_8bit_blockwise_grad_fp16
,
lib
.
cademamix_8bit_blockwise_grad_bf16
,
),
}
def
_optimizer_update_32bit_impl
(
optimizer_name
:
str
,
g
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
state1
:
torch
.
Tensor
,
state2
:
Optional
[
torch
.
Tensor
],
unorm_vec
:
Optional
[
torch
.
Tensor
],
max_unorm
:
float
,
param_norm
:
float
,
beta1
:
float
,
beta2
:
float
,
beta3
:
float
,
alpha
:
float
,
eps
:
float
,
weight_decay
:
float
,
step
:
int
,
lr
:
float
,
gnorm_scale
:
float
,
skip_zeros
=
False
,
)
->
None
:
optim_fns
=
str2optimizer32bit
.
get
(
optimizer_name
,
None
)
if
optim_fns
is
None
:
raise
ValueError
(
f
"Unsupported optimizer name:
{
optimizer_name
}
. Supported optimizers:
{
list
(
str2optimizer8bit_blockwise
.
keys
())
}
"
)
if
g
.
dtype
==
torch
.
float32
:
optim_func
=
optim_fns
[
0
]
elif
g
.
dtype
==
torch
.
float16
:
optim_func
=
optim_fns
[
1
]
elif
g
.
dtype
==
torch
.
bfloat16
and
len
(
optim_fns
)
==
3
:
optim_func
=
optim_fns
[
2
]
else
:
raise
ValueError
(
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
,
)
with
_cuda_device_of
(
g
):
optim_func
(
get_ptr
(
g
),
get_ptr
(
p
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
beta3
),
ct
.
c_float
(
alpha
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
weight_decay
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
def
_optimizer_update_8bit_blockwise_impl
(
optimizer_name
:
str
,
g
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
state1
:
torch
.
Tensor
,
state2
:
Optional
[
torch
.
Tensor
],
beta1
:
float
,
beta2
:
float
,
beta3
:
float
,
alpha
:
float
,
eps
:
float
,
step
:
int
,
lr
:
float
,
qmap1
:
torch
.
Tensor
,
qmap2
:
Optional
[
torch
.
Tensor
],
absmax1
:
torch
.
Tensor
,
absmax2
:
Optional
[
torch
.
Tensor
],
weight_decay
:
float
,
gnorm_scale
:
float
,
skip_zeros
=
False
,
)
->
None
:
# torch._check(
# g.numel() == p.numel(),
# lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
# )
# compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# torch._check(
# g.dtype in compute_dtypes,
# lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
# )
# torch._check(
# g.dtype == p.dtype,
# lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
# )
# torch._check(
# state1.dtype == torch.uint8,
# lambda: f"state1 must be uint8, got {state1.dtype}",
# )
# torch._check(
# qmap1.dtype == absmax1.dtype == torch.float32,
# lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
# )
# if state2 is not None:
# torch._check(
# state2.dtype == torch.uint8,
# lambda: f"state2 must be uint8, got {state2.dtype}",
# )
# torch._check(
# qmap2.dtype == absmax2.dtype == torch.float32,
# lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
# )
optimizer_fns
=
str2optimizer8bit_blockwise
.
get
(
optimizer_name
)
if
optimizer_fns
is
None
:
raise
ValueError
(
f
"Unsupported optimizer name:
{
optimizer_name
}
. Supported optimizers:
{
list
(
str2optimizer8bit_blockwise
.
keys
())
}
"
)
if
g
.
dtype
==
torch
.
float32
:
optimizer_fn
=
optimizer_fns
[
0
]
elif
g
.
dtype
==
torch
.
float16
:
optimizer_fn
=
optimizer_fns
[
1
]
elif
g
.
dtype
==
torch
.
bfloat16
:
optimizer_fn
=
optimizer_fns
[
2
]
else
:
raise
ValueError
(
f
"Unsupported gradient dtype:
{
g
.
dtype
}
. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
)
with
_cuda_device_of
(
g
):
optimizer_fn
(
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
beta3
),
ct
.
c_float
(
alpha
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
register_kernel
(
"bitsandbytes::optimizer_update_8bit_blockwise"
,
"cuda"
)(
_optimizer_update_8bit_blockwise_impl
)
register_kernel
(
"bitsandbytes::optimizer_update_32bit"
,
"cuda"
)(
_optimizer_update_32bit_impl
)
bitsandbytes/functional.py
View file @
941681da
...
...
@@ -20,41 +20,6 @@ from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib
name2qmap
=
{}
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit
=
{
"adam"
:
(
lib
.
cadam32bit_grad_fp32
,
lib
.
cadam32bit_grad_fp16
,
lib
.
cadam32bit_grad_bf16
,
),
"momentum"
:
(
lib
.
cmomentum32bit_grad_32
,
lib
.
cmomentum32bit_grad_16
,
),
"rmsprop"
:
(
lib
.
crmsprop32bit_grad_32
,
lib
.
crmsprop32bit_grad_16
,
),
"lion"
:
(
lib
.
clion32bit_grad_fp32
,
lib
.
clion32bit_grad_fp16
,
lib
.
clion32bit_grad_bf16
,
),
"adagrad"
:
(
lib
.
cadagrad32bit_grad_32
,
lib
.
cadagrad32bit_grad_16
,
),
"lamb"
:
(
lib
.
cadam32bit_grad_fp32
,
lib
.
cadam32bit_grad_fp16
,
lib
.
cadam32bit_grad_bf16
,
),
"ademamix"
:
(
lib
.
cademamix32bit_grad_fp32
,
lib
.
cademamix32bit_grad_fp16
,
lib
.
cademamix32bit_grad_bf16
,
),
}
str2optimizer8bit
=
{
"adam"
:
(
lib
.
cadam_static_8bit_grad_32
,
...
...
@@ -82,39 +47,6 @@ str2optimizer8bit = {
),
}
str2optimizer8bit_blockwise
=
{
"adam"
:
(
lib
.
cadam_8bit_blockwise_grad_fp32
,
lib
.
cadam_8bit_blockwise_grad_fp16
,
lib
.
cadam_8bit_blockwise_grad_bf16
,
),
"momentum"
:
(
lib
.
cmomentum_8bit_blockwise_grad_fp32
,
lib
.
cmomentum_8bit_blockwise_grad_fp16
,
lib
.
cmomentum_8bit_blockwise_grad_bf16
,
),
"rmsprop"
:
(
lib
.
crmsprop_8bit_blockwise_grad_fp32
,
lib
.
crmsprop_8bit_blockwise_grad_fp16
,
lib
.
crmsprop_8bit_blockwise_grad_bf16
,
),
"lion"
:
(
lib
.
clion_8bit_blockwise_grad_fp32
,
lib
.
clion_8bit_blockwise_grad_fp16
,
lib
.
clion_8bit_blockwise_grad_bf16
,
),
"adagrad"
:
(
lib
.
cadagrad_8bit_blockwise_grad_fp32
,
lib
.
cadagrad_8bit_blockwise_grad_fp16
,
lib
.
cadagrad_8bit_blockwise_grad_bf16
,
),
"ademamix"
:
(
lib
.
cademamix_8bit_blockwise_grad_fp32
,
lib
.
cademamix_8bit_blockwise_grad_fp16
,
lib
.
cademamix_8bit_blockwise_grad_bf16
,
),
}
class
GlobalPageManager
:
_instance
=
None
...
...
@@ -422,8 +354,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
for
t
in
tensors
:
# NULL pointers and paged tensors are OK.
if
t
is
not
None
and
not
getattr
(
t
,
"is_paged"
,
False
):
on_gpu
&=
t
.
is_cuda
gpu_ids
.
add
(
t
.
device
.
index
)
on_gpu
&=
t
.
device
.
type
!=
"cpu"
gpu_ids
.
add
(
(
t
.
device
.
type
,
t
.
device
.
index
)
)
if
not
on_gpu
:
raise
RuntimeError
(
...
...
@@ -1252,41 +1184,27 @@ def optimizer_update_32bit(
if
max_unorm
>
0.0
:
param_norm
=
torch
.
norm
(
p
.
data
.
float
())
optim_func
=
None
if
g
.
dtype
==
torch
.
float32
:
optim_func
=
str2optimizer32bit
[
optimizer_name
][
0
]
elif
g
.
dtype
==
torch
.
float16
:
optim_func
=
str2optimizer32bit
[
optimizer_name
][
1
]
elif
g
.
dtype
==
torch
.
bfloat16
and
len
(
str2optimizer32bit
[
optimizer_name
])
==
3
:
optim_func
=
str2optimizer32bit
[
optimizer_name
][
2
]
else
:
raise
ValueError
(
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
,
)
is_on_gpu
([
g
,
p
,
state1
,
state2
,
unorm_vec
])
with
_cuda_device_of
(
g
):
optim_func
(
get_ptr
(
g
),
get_ptr
(
p
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
beta3
),
ct
.
c_float
(
alpha
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
weight_decay
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
torch
.
ops
.
bitsandbytes
.
optimizer_update_32bit
(
optimizer_name
,
g
,
p
,
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
skip_zeros
,
)
@
deprecated
(
...
...
@@ -1449,45 +1367,29 @@ def optimizer_update_8bit_blockwise(
)
->
None
:
optim_func
=
None
if
g
.
dtype
==
torch
.
float32
and
state1
.
dtype
==
torch
.
uint8
:
optim_func
=
str2optimizer8bit_blockwise
[
optimizer_name
][
0
]
elif
g
.
dtype
==
torch
.
float16
and
state1
.
dtype
==
torch
.
uint8
:
optim_func
=
str2optimizer8bit_blockwise
[
optimizer_name
][
1
]
elif
(
g
.
dtype
==
torch
.
bfloat16
and
state1
.
dtype
==
torch
.
uint8
and
len
(
str2optimizer8bit_blockwise
[
optimizer_name
])
==
3
):
optim_func
=
str2optimizer8bit_blockwise
[
optimizer_name
][
2
]
else
:
raise
ValueError
(
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
,
)
is_on_gpu
([
p
,
g
,
state1
,
state2
,
qmap1
,
qmap2
,
absmax1
,
absmax2
])
with
_cuda_device_of
(
g
):
optim_func
(
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state2
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
beta3
),
ct
.
c_float
(
alpha
),
ct
.
c_float
(
eps
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
get_ptr
(
qmap1
),
get_ptr
(
qmap2
),
get_ptr
(
absmax1
),
get_ptr
(
absmax2
),
ct
.
c_float
(
weight_decay
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
torch
.
ops
.
bitsandbytes
.
optimizer_update_8bit_blockwise
(
optimizer_name
,
g
,
p
,
state1
,
state2
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
step
,
lr
,
qmap1
,
qmap2
,
absmax1
,
absmax2
,
weight_decay
,
gnorm_scale
,
skip_zeros
,
)
@
deprecated
(
"This function is deprecated and will be removed in a future release."
,
category
=
FutureWarning
)
...
...
bitsandbytes/optim/optimizer.py
View file @
941681da
...
...
@@ -10,6 +10,7 @@ from typing import Optional
import
torch
import
bitsandbytes.functional
as
F
from
bitsandbytes.utils
import
sync_gpu
class
MockArgs
:
...
...
@@ -279,6 +280,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self
.
initialized
=
True
# if self.is_paged: self.page_mng.prefetch_all()
p
=
None
for
gindex
,
group
in
enumerate
(
self
.
param_groups
):
for
pindex
,
p
in
enumerate
(
group
[
"params"
]):
if
p
.
grad
is
None
:
...
...
@@ -289,11 +291,11 @@ class Optimizer8bit(torch.optim.Optimizer):
self
.
prefetch_state
(
p
)
self
.
update_step
(
group
,
p
,
gindex
,
pindex
)
torch
.
cuda
.
synchronize
(
)
if
self
.
is_paged
:
sync_gpu
(
p
)
if
self
.
is_paged
and
p
is
not
None
:
# all paged operations are asynchronous, we need
# to sync to make sure all tensors are in the right state
torch
.
cuda
.
synchronize
(
)
sync_gpu
(
p
)
return
loss
...
...
bitsandbytes/utils.py
View file @
941681da
...
...
@@ -209,3 +209,10 @@ def unpack_tensor_to_dict(tensor_data):
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING
=
{
"row"
:
0
,
"col32"
:
1
,
"col_turing"
:
2
,
"col_ampere"
:
3
}
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING
=
{
val
:
name
for
(
name
,
val
)
in
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING
.
items
()}
def
sync_gpu
(
t
:
torch
.
Tensor
):
if
t
.
device
.
type
==
"cuda"
:
torch
.
cuda
.
synchronize
()
elif
t
.
device
.
type
==
"xpu"
:
torch
.
xpu
.
synchronize
()
tests/helpers.py
View file @
941681da
...
...
@@ -18,12 +18,12 @@ BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (boo
@
functools
.
cache
def
get_available_devices
():
def
get_available_devices
(
no_cpu
=
False
):
if
"BNB_TEST_DEVICE"
in
os
.
environ
:
# If the environment variable is set, use it directly.
return
[
os
.
environ
[
"BNB_TEST_DEVICE"
]]
return
[
d
for
d
in
os
.
environ
[
"BNB_TEST_DEVICE"
]
if
d
.
lower
()
!=
"cpu"
]
devices
=
[]
if
HIP_ENVIRONMENT
else
[
"cpu"
]
devices
=
[]
if
HIP_ENVIRONMENT
else
[
"cpu"
]
if
not
no_cpu
else
[]
if
hasattr
(
torch
,
"accelerator"
):
# PyTorch 2.6+ - determine accelerator using agnostic API.
...
...
tests/test_optim.py
View file @
941681da
...
...
@@ -11,7 +11,8 @@ import torch
import
bitsandbytes
as
bnb
import
bitsandbytes.functional
as
F
from
tests.helpers
import
describe_dtype
,
id_formatter
from
bitsandbytes.utils
import
sync_gpu
from
tests.helpers
import
describe_dtype
,
get_available_devices
,
id_formatter
# import apex
...
...
@@ -168,7 +169,8 @@ optimizer_names_32bit = [
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dim1"
,
[
1024
],
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim2"
,
[
32
,
1024
,
4097
,
1
],
ids
=
id_formatter
(
"dim2"
))
def
test_optimizer32bit
(
requires_cuda
,
dim1
,
dim2
,
gtype
,
optim_name
):
@
pytest
.
mark
.
parametrize
(
"device"
,
get_available_devices
(
no_cpu
=
True
),
ids
=
id_formatter
(
"device"
))
def
test_optimizer32bit
(
dim1
,
dim2
,
gtype
,
optim_name
,
device
):
if
optim_name
.
startswith
(
"paged_"
)
and
sys
.
platform
==
"win32"
:
pytest
.
skip
(
"Paged optimizers can have issues on Windows."
)
...
...
@@ -176,7 +178,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
pytest
.
skip
()
if
dim1
==
1
and
dim2
==
1
:
return
p1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
p1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
device
,
dtype
=
gtype
)
*
0.1
p2
=
p1
.
clone
()
p1
=
p1
.
float
()
...
...
@@ -191,7 +193,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
atol
,
rtol
=
1e-4
,
1e-3
for
i
in
range
(
k
):
g
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.01
g
=
torch
.
randn
(
dim1
,
dim2
,
device
=
device
,
dtype
=
gtype
)
*
0.01
p1
.
grad
=
g
.
clone
().
float
()
p2
.
grad
=
g
.
clone
()
...
...
@@ -201,14 +203,14 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
for
name1
,
name2
in
str2statenames
[
optim_name
]:
torch
.
testing
.
assert_close
(
torch_optimizer
.
state
[
p1
][
name1
],
bnb_optimizer
.
state
[
p2
][
name2
].
cuda
(
),
bnb_optimizer
.
state
[
p2
][
name2
].
to
(
device
),
atol
=
atol
,
rtol
=
rtol
,
)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 1
0
errors for Lion
assert_most_approx_close
(
p1
,
p2
.
float
(),
atol
=
atol
,
rtol
=
rtol
,
max_error_count
=
1
0
)
# allow up to 1
5
errors for Lion
assert_most_approx_close
(
p1
,
p2
.
float
(),
atol
=
atol
,
rtol
=
rtol
,
max_error_count
=
1
5
)
if
i
%
(
k
//
5
)
==
0
and
i
>
0
:
path
=
get_temp_dir
()
...
...
@@ -247,7 +249,8 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
@
pytest
.
mark
.
parametrize
(
"dim1"
,
[
1024
],
ids
=
id_formatter
(
"dim1"
))
@
pytest
.
mark
.
parametrize
(
"dim2"
,
[
32
,
1024
,
4097
],
ids
=
id_formatter
(
"dim2"
))
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float32
,
torch
.
float16
],
ids
=
describe_dtype
)
def
test_global_config
(
requires_cuda
,
dim1
,
dim2
,
gtype
):
@
pytest
.
mark
.
parametrize
(
"device"
,
get_available_devices
(
no_cpu
=
True
))
def
test_global_config
(
dim1
,
dim2
,
gtype
,
device
):
if
dim1
==
1
and
dim2
==
1
:
return
p1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cpu"
,
dtype
=
gtype
)
*
0.1
...
...
@@ -263,9 +266,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
bnb
.
optim
.
GlobalOptimManager
.
get_instance
().
override_config
(
p3
,
"optim_bits"
,
8
)
bnb
.
optim
.
GlobalOptimManager
.
get_instance
().
register_parameters
([
p1
,
p2
,
p3
])
p1
=
p1
.
cuda
(
)
p2
=
p2
.
cuda
(
)
p3
=
p3
.
cuda
(
)
p1
=
p1
.
to
(
device
)
p2
=
p2
.
to
(
device
)
p3
=
p3
.
to
(
device
)
adam2
=
bnb
.
optim
.
Adam
([
p1
,
p2
,
p3
],
lr
,
(
beta1
,
beta2
),
eps
)
...
...
@@ -275,9 +278,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
atol
,
rtol
=
1e-4
,
1e-3
for
i
in
range
(
50
):
g1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
+
0.001
g2
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
+
0.001
g3
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
+
0.001
g1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
device
,
dtype
=
gtype
)
*
0.1
+
0.001
g2
=
torch
.
randn
(
dim1
,
dim2
,
device
=
device
,
dtype
=
gtype
)
*
0.1
+
0.001
g3
=
torch
.
randn
(
dim1
,
dim2
,
device
=
device
,
dtype
=
gtype
)
*
0.1
+
0.001
p1
.
grad
=
g1
p2
.
grad
=
g2
p3
.
grad
=
g3
...
...
@@ -302,13 +305,14 @@ optimizer_names_8bit = [
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"dim2"
,
[
32
,
1024
,
4097
],
ids
=
id_formatter
(
"dim2"
))
@
pytest
.
mark
.
parametrize
(
"dim1"
,
[
1024
],
ids
=
id_formatter
(
"dim1"
))
def
test_optimizer8bit
(
requires_cuda
,
dim1
,
dim2
,
gtype
,
optim_name
):
@
pytest
.
mark
.
parametrize
(
"device"
,
get_available_devices
(
no_cpu
=
True
))
def
test_optimizer8bit
(
dim1
,
dim2
,
gtype
,
optim_name
,
device
):
torch
.
set_printoptions
(
precision
=
6
)
if
dim1
==
1
and
dim2
==
1
:
return
p1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
p1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
device
,
dtype
=
gtype
)
*
0.1
p2
=
p1
.
clone
()
p1
=
p1
.
float
()
blocksize
=
256
...
...
@@ -330,15 +334,15 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
relerrors
=
[]
for
i
in
range
(
50
):
g
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.01
g
=
torch
.
randn
(
dim1
,
dim2
,
device
=
device
,
dtype
=
gtype
)
*
0.01
p1
.
grad
=
g
.
clone
().
float
()
p2
.
grad
=
g
.
clone
()
bnb_optimizer
.
step
()
torch_optimizer
.
step
()
bnb_optimizer
.
step
()
# since Lion can have pretty noisy updates where things lie at the boundary
assert_most_approx_close
(
p1
,
p2
.
float
(),
patol
,
prtol
,
max_error_count
=
0
)
#
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
dequant_states
=
[]
for
name1
,
name2
,
qmap
,
max_val
in
str2statenames
[
optim_name
]:
...
...
@@ -368,7 +372,7 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
)
num_not_close
=
torch
.
isclose
(
torch_optimizer
.
state
[
p1
][
name1
],
s1
,
atol
=
atol
,
rtol
=
rtol
)
==
0
#
assert num_not_close.sum().item() < 20
assert
num_not_close
.
sum
().
item
()
<
20
dequant_states
.
append
(
s1
.
clone
())
err
=
torch
.
abs
(
p1
-
p2
)
...
...
@@ -549,25 +553,25 @@ optimizer_names_benchmark = [
@
pytest
.
mark
.
parametrize
(
"gtype"
,
[
torch
.
float32
,
torch
.
bfloat16
,
torch
.
float16
],
ids
=
describe_dtype
)
@
pytest
.
mark
.
parametrize
(
"optim_name"
,
optimizer_names_benchmark
,
ids
=
id_formatter
(
"opt"
))
@
pytest
.
mark
.
benchmark
def
test_benchmark_blockwise
(
dim1
,
dim2
,
gtype
,
optim_name
):
def
test_benchmark_blockwise
(
dim1
,
dim2
,
gtype
,
optim_name
,
device
):
if
dim1
==
1
and
dim2
==
1
:
return
p1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.1
p1
=
torch
.
randn
(
dim1
,
dim2
,
device
=
device
,
dtype
=
gtype
)
*
0.1
bnb_optimizer
=
str2optimizers
[
optim_name
][
1
]([
p1
])
g
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.01
g
=
torch
.
randn
(
dim1
,
dim2
,
device
=
device
,
dtype
=
gtype
)
*
0.01
p1
.
grad
=
g
total_steps
=
500
for
i
in
range
(
total_steps
):
if
i
==
total_steps
//
5
:
# 100 iterations for burn-in
torch
.
cuda
.
synchronize
(
)
sync_gpu
(
p1
)
t0
=
time
.
time
()
bnb_optimizer
.
step
()
torch
.
cuda
.
synchronize
(
)
sync_gpu
(
p1
)
s
=
time
.
time
()
-
t0
print
(
""
)
params
=
(
total_steps
-
total_steps
//
5
)
*
dim1
*
dim2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment