Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
4b025748
Commit
4b025748
authored
Sep 15, 2025
by
Matthew Douglas
Browse files
Lint fix
parent
1813b058
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
161 additions
and
64 deletions
+161
-64
bitsandbytes/backends/default/ops.py
bitsandbytes/backends/default/ops.py
+37
-12
bitsandbytes/backends/triton/kernels_optim.py
bitsandbytes/backends/triton/kernels_optim.py
+123
-51
bitsandbytes/backends/triton/ops.py
bitsandbytes/backends/triton/ops.py
+1
-1
No files found.
bitsandbytes/backends/default/ops.py
View file @
4b025748
...
@@ -320,6 +320,7 @@ name2optimizer_id = {
...
@@ -320,6 +320,7 @@ name2optimizer_id = {
"ademamix"
:
ADEMAMIX
,
"ademamix"
:
ADEMAMIX
,
}
}
@
torch
.
compile
@
torch
.
compile
def
_optimizer_precondition_32bit
(
def
_optimizer_precondition_32bit
(
g
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
...
@@ -525,29 +526,53 @@ def _(
...
@@ -525,29 +526,53 @@ def _(
if
optimizer_name
==
"lion"
:
if
optimizer_name
==
"lion"
:
_optimizer_update_32bit
(
_optimizer_update_32bit
(
g
,
p
,
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
g
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
p
,
lr
,
gnorm_scale
,
optimizer_id
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
,
)
)
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
unorm_vec
.
zero_
()
unorm_vec
.
zero_
()
_optimizer_precondition_32bit
(
_optimizer_precondition_32bit
(
g
,
p
,
state1
,
state2
,
unorm_vec
,
g
,
p
,
state1
,
state2
,
unorm_vec
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
)
)
else
:
else
:
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
unorm_vec
.
zero_
()
unorm_vec
.
zero_
()
_optimizer_precondition_32bit
(
_optimizer_precondition_32bit
(
g
,
p
,
state1
,
state2
,
unorm_vec
,
g
,
p
,
state1
,
state2
,
unorm_vec
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
)
)
_optimizer_update_32bit
(
_optimizer_update_32bit
(
g
,
p
,
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
g
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
p
,
lr
,
gnorm_scale
,
optimizer_id
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
,
)
)
bitsandbytes/backends/triton/kernels_optim.py
View file @
4b025748
...
@@ -4,6 +4,7 @@ import torch
...
@@ -4,6 +4,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
# from triton.language.extra import libdevice
# from triton.language.extra import libdevice
MOMENTUM
=
0
MOMENTUM
=
0
...
@@ -23,6 +24,7 @@ name2optimizer_id = {
...
@@ -23,6 +24,7 @@ name2optimizer_id = {
"ademamix"
:
ADEMAMIX
,
"ademamix"
:
ADEMAMIX
,
}
}
@
triton
.
jit
@
triton
.
jit
def
_optimizer_precondition_2state_32bit
(
def
_optimizer_precondition_2state_32bit
(
g_ptr
,
g_ptr
,
...
@@ -367,34 +369,104 @@ def optimizer_update_32bit_impl(
...
@@ -367,34 +369,104 @@ def optimizer_update_32bit_impl(
if
optimizer_name
==
"lion"
:
if
optimizer_name
==
"lion"
:
fn_update
[
grid
](
fn_update
[
grid
](
g
,
p
,
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
g
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
p
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
skip_zeros
,
state1
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
skip_zeros
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
)
)
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
unorm_vec
.
zero_
()
unorm_vec
.
zero_
()
fn_preprocess
[
grid
](
fn_preprocess
[
grid
](
g
,
p
,
state1
,
state2
,
unorm_vec
,
g
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
p
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
state1
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
state2
,
unorm_vec
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
)
)
else
:
else
:
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
unorm_vec
.
zero_
()
unorm_vec
.
zero_
()
fn_preprocess
[
grid
](
fn_preprocess
[
grid
](
g
,
p
,
state1
,
state2
,
unorm_vec
,
g
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
p
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
state1
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
state2
,
unorm_vec
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
)
)
fn_update
[
grid
](
fn_update
[
grid
](
g
,
p
,
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
g
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
p
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
skip_zeros
,
state1
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
skip_zeros
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
)
)
bitsandbytes/backends/triton/ops.py
View file @
4b025748
...
@@ -3,7 +3,7 @@ from typing import Optional
...
@@ -3,7 +3,7 @@ from typing import Optional
import
torch
import
torch
from
.
import
triton_kernels
,
kernels
_optim
from
.
import
kernels_optim
,
triton_
kernels
# currently codes unused, kept for reference
# currently codes unused, kept for reference
# Should be the same for quant/dequant
# Should be the same for quant/dequant
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment