Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
4b025748
Commit
4b025748
authored
Sep 15, 2025
by
Matthew Douglas
Browse files
Lint fix
parent
1813b058
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
161 additions
and
64 deletions
+161
-64
bitsandbytes/backends/default/ops.py
bitsandbytes/backends/default/ops.py
+37
-12
bitsandbytes/backends/triton/kernels_optim.py
bitsandbytes/backends/triton/kernels_optim.py
+123
-51
bitsandbytes/backends/triton/ops.py
bitsandbytes/backends/triton/ops.py
+1
-1
No files found.
bitsandbytes/backends/default/ops.py
View file @
4b025748
...
@@ -320,6 +320,7 @@ name2optimizer_id = {
...
@@ -320,6 +320,7 @@ name2optimizer_id = {
"ademamix"
:
ADEMAMIX
,
"ademamix"
:
ADEMAMIX
,
}
}
@
torch
.
compile
@
torch
.
compile
def
_optimizer_precondition_32bit
(
def
_optimizer_precondition_32bit
(
g
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
...
@@ -525,29 +526,53 @@ def _(
...
@@ -525,29 +526,53 @@ def _(
if
optimizer_name
==
"lion"
:
if
optimizer_name
==
"lion"
:
_optimizer_update_32bit
(
_optimizer_update_32bit
(
g
,
p
,
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
g
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
p
,
lr
,
gnorm_scale
,
optimizer_id
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
,
)
)
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
unorm_vec
.
zero_
()
unorm_vec
.
zero_
()
_optimizer_precondition_32bit
(
_optimizer_precondition_32bit
(
g
,
p
,
state1
,
state2
,
unorm_vec
,
g
,
p
,
state1
,
state2
,
unorm_vec
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
)
)
else
:
else
:
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
unorm_vec
.
zero_
()
unorm_vec
.
zero_
()
_optimizer_precondition_32bit
(
_optimizer_precondition_32bit
(
g
,
p
,
state1
,
state2
,
unorm_vec
,
g
,
p
,
state1
,
state2
,
unorm_vec
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
)
)
_optimizer_update_32bit
(
_optimizer_update_32bit
(
g
,
p
,
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
g
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
p
,
lr
,
gnorm_scale
,
optimizer_id
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
lr
,
gnorm_scale
,
optimizer_id
,
)
)
bitsandbytes/backends/triton/kernels_optim.py
View file @
4b025748
...
@@ -4,6 +4,7 @@ import torch
...
@@ -4,6 +4,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
# from triton.language.extra import libdevice
# from triton.language.extra import libdevice
MOMENTUM
=
0
MOMENTUM
=
0
...
@@ -23,6 +24,7 @@ name2optimizer_id = {
...
@@ -23,6 +24,7 @@ name2optimizer_id = {
"ademamix"
:
ADEMAMIX
,
"ademamix"
:
ADEMAMIX
,
}
}
@
triton
.
jit
@
triton
.
jit
def
_optimizer_precondition_2state_32bit
(
def
_optimizer_precondition_2state_32bit
(
g_ptr
,
g_ptr
,
...
@@ -49,32 +51,32 @@ def _optimizer_precondition_2state_32bit(
...
@@ -49,32 +51,32 @@ def _optimizer_precondition_2state_32bit(
block_start_idx
=
pid
*
N_PER_TH
block_start_idx
=
pid
*
N_PER_TH
offsets
=
block_start_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
*
N_PER_TH
)
offsets
=
block_start_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
*
N_PER_TH
)
mask
=
offsets
<
n_elements
mask
=
offsets
<
n_elements
g_vals
=
tl
.
load
(
g_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
g_vals
=
tl
.
load
(
g_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s1_vals
=
tl
.
load
(
state1_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s1_vals
=
tl
.
load
(
state1_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s2_vals
=
tl
.
load
(
state2_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s2_vals
=
tl
.
load
(
state2_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
g_vals
=
gnorm_scale
*
g_vals
g_vals
=
gnorm_scale
*
g_vals
correction1
=
1.0
/
(
1.0
-
beta1_step
)
correction1
=
1.0
/
(
1.0
-
beta1_step
)
correction2
=
1.0
/
(
1.0
-
beta2_step
)
correction2
=
1.0
/
(
1.0
-
beta2_step
)
if
OPTIMIZER_ID
==
3
:
# ADAM
if
OPTIMIZER_ID
==
3
:
# ADAM
s1_vals
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
s1_vals
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
s2_vals
=
s2_vals
*
beta2
+
(
1.0
-
beta2
)
*
g_vals
*
g_vals
s2_vals
=
s2_vals
*
beta2
+
(
1.0
-
beta2
)
*
g_vals
*
g_vals
s1_vals
=
s1_vals
*
correction1
s1_vals
=
s1_vals
*
correction1
s2_vals
=
s2_vals
*
correction2
s2_vals
=
s2_vals
*
correction2
update_vals
=
s1_vals
/
(
tl
.
sqrt
(
s2_vals
)
+
eps
)
update_vals
=
s1_vals
/
(
tl
.
sqrt
(
s2_vals
)
+
eps
)
update_norm
=
update_vals
*
update_vals
update_norm
=
update_vals
*
update_vals
elif
OPTIMIZER_ID
==
5
:
# ADEMAMIX
elif
OPTIMIZER_ID
==
5
:
# ADEMAMIX
update_norm
=
s1_vals
update_norm
=
s1_vals
total_norm
=
tl
.
sum
(
tl
.
where
(
mask
,
update_norm
,
0.0
))
total_norm
=
tl
.
sum
(
tl
.
where
(
mask
,
update_norm
,
0.0
))
tl
.
atomic_add
(
unorm_ptr
,
total_norm
)
tl
.
atomic_add
(
unorm_ptr
,
total_norm
)
...
@@ -89,7 +91,7 @@ def _optimizer_precondition_1state_32bit(
...
@@ -89,7 +91,7 @@ def _optimizer_precondition_1state_32bit(
beta2
:
tl
.
constexpr
,
beta2
:
tl
.
constexpr
,
eps
:
tl
.
constexpr
,
eps
:
tl
.
constexpr
,
weight_decay
,
weight_decay
,
step
,
step
,
beta1_step
,
beta1_step
,
beta2_step
,
beta2_step
,
lr
,
lr
,
...
@@ -104,12 +106,12 @@ def _optimizer_precondition_1state_32bit(
...
@@ -104,12 +106,12 @@ def _optimizer_precondition_1state_32bit(
block_start_idx
=
pid
*
N_PER_TH
block_start_idx
=
pid
*
N_PER_TH
offsets
=
block_start_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
*
N_PER_TH
)
offsets
=
block_start_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
*
N_PER_TH
)
mask
=
offsets
<
n_elements
mask
=
offsets
<
n_elements
g_vals
=
tl
.
load
(
g_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
g_vals
=
tl
.
load
(
g_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s1_vals
=
tl
.
load
(
state1_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s1_vals
=
tl
.
load
(
state1_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
g_vals
=
gnorm_scale
*
g_vals
g_vals
=
gnorm_scale
*
g_vals
if
OPTIMIZER_ID
==
0
:
# MOMENTUM
if
OPTIMIZER_ID
==
0
:
# MOMENTUM
if
step
==
1
:
if
step
==
1
:
s1_vals
=
g_vals
s1_vals
=
g_vals
...
@@ -130,9 +132,9 @@ def _optimizer_precondition_1state_32bit(
...
@@ -130,9 +132,9 @@ def _optimizer_precondition_1state_32bit(
s1_vals
=
s1_vals
+
g_vals
*
g_vals
s1_vals
=
s1_vals
+
g_vals
*
g_vals
update_vals
=
g_vals
/
(
tl
.
sqrt
(
s1_vals
)
+
eps
)
update_vals
=
g_vals
/
(
tl
.
sqrt
(
s1_vals
)
+
eps
)
update_norm
=
update_vals
*
update_vals
update_norm
=
update_vals
*
update_vals
total_norm
=
tl
.
sum
(
tl
.
where
(
mask
,
update_norm
,
0.0
))
total_norm
=
tl
.
sum
(
tl
.
where
(
mask
,
update_norm
,
0.0
))
tl
.
atomic_add
(
unorm_ptr
,
total_norm
)
tl
.
atomic_add
(
unorm_ptr
,
total_norm
)
...
@@ -151,7 +153,7 @@ def _optimizer_update_2state_32bit_triton_kernel(
...
@@ -151,7 +153,7 @@ def _optimizer_update_2state_32bit_triton_kernel(
alpha
,
alpha
,
eps
:
tl
.
constexpr
,
eps
:
tl
.
constexpr
,
weight_decay
:
tl
.
constexpr
,
weight_decay
:
tl
.
constexpr
,
step
,
step
,
beta1_step
,
beta1_step
,
beta2_step
,
beta2_step
,
lr
,
lr
,
...
@@ -167,23 +169,23 @@ def _optimizer_update_2state_32bit_triton_kernel(
...
@@ -167,23 +169,23 @@ def _optimizer_update_2state_32bit_triton_kernel(
block_start_idx
=
pid
*
N_PER_TH
block_start_idx
=
pid
*
N_PER_TH
offsets
=
block_start_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
*
N_PER_TH
)
offsets
=
block_start_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
*
N_PER_TH
)
mask
=
offsets
<
n_elements
mask
=
offsets
<
n_elements
g_vals
=
tl
.
load
(
g_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
g_vals
=
tl
.
load
(
g_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
p_vals
=
tl
.
load
(
p_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
p_vals
=
tl
.
load
(
p_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
s1_vals
=
tl
.
load
(
state1_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s1_vals
=
tl
.
load
(
state1_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s2_vals
=
tl
.
load
(
state2_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s2_vals
=
tl
.
load
(
state2_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
if
OPTIMIZER_ID
==
5
:
# ADEMAMIX
if
OPTIMIZER_ID
==
5
:
# ADEMAMIX
s3_vals
=
tl
.
load
(
state1_ptr
+
n_elements
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s3_vals
=
tl
.
load
(
state1_ptr
+
n_elements
+
offsets
,
mask
=
mask
,
other
=
0.0
)
g_vals
=
gnorm_scale
*
g_vals
g_vals
=
gnorm_scale
*
g_vals
update_scale
=
1.0
update_scale
=
1.0
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
current_unorm
=
tl
.
sqrt
(
tl
.
load
(
unorm_ptr
))
current_unorm
=
tl
.
sqrt
(
tl
.
load
(
unorm_ptr
))
if
current_unorm
>
max_unorm
*
param_norm
:
if
current_unorm
>
max_unorm
*
param_norm
:
update_scale
=
(
max_unorm
*
param_norm
)
/
current_unorm
update_scale
=
(
max_unorm
*
param_norm
)
/
current_unorm
if
OPTIMIZER_ID
==
3
:
# ADAM
if
OPTIMIZER_ID
==
3
:
# ADAM
s1_vals
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
s1_vals
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
s2_vals
=
s2_vals
*
beta2
+
(
1.0
-
beta2
)
*
g_vals
*
g_vals
s2_vals
=
s2_vals
*
beta2
+
(
1.0
-
beta2
)
*
g_vals
*
g_vals
...
@@ -197,8 +199,8 @@ def _optimizer_update_2state_32bit_triton_kernel(
...
@@ -197,8 +199,8 @@ def _optimizer_update_2state_32bit_triton_kernel(
update_val
=
update_scale
*
step_size
*
(
s1_vals
/
(
tl
.
sqrt
(
s2_vals
)
+
eps
*
correction2
))
update_val
=
update_scale
*
step_size
*
(
s1_vals
/
(
tl
.
sqrt
(
s2_vals
)
+
eps
*
correction2
))
p_vals
=
p_vals
+
update_val
p_vals
=
p_vals
+
update_val
elif
OPTIMIZER_ID
==
5
:
# ADEMAMIX
elif
OPTIMIZER_ID
==
5
:
# ADEMAMIX
s1_vals
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
# m1
s1_vals
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
# m1
s3_vals
=
s3_vals
*
beta3
+
(
1.0
-
beta3
)
*
g_vals
# m2
s3_vals
=
s3_vals
*
beta3
+
(
1.0
-
beta3
)
*
g_vals
# m2
s2_vals
=
s2_vals
*
beta2
+
(
1.0
-
beta2
)
*
g_vals
*
g_vals
# nu
s2_vals
=
s2_vals
*
beta2
+
(
1.0
-
beta2
)
*
g_vals
*
g_vals
# nu
...
@@ -208,15 +210,15 @@ def _optimizer_update_2state_32bit_triton_kernel(
...
@@ -208,15 +210,15 @@ def _optimizer_update_2state_32bit_triton_kernel(
if
weight_decay
>
0.0
:
if
weight_decay
>
0.0
:
p_vals
=
p_vals
*
(
1.0
-
lr
*
weight_decay
)
p_vals
=
p_vals
*
(
1.0
-
lr
*
weight_decay
)
mixed_momentum
=
(
s1_vals
/
correction1
)
+
(
alpha
*
s3_vals
)
mixed_momentum
=
(
s1_vals
/
correction1
)
+
(
alpha
*
s3_vals
)
adaptive_term
=
(
tl
.
sqrt
(
s2_vals
)
/
correction2
)
+
eps
adaptive_term
=
(
tl
.
sqrt
(
s2_vals
)
/
correction2
)
+
eps
p_vals
=
p_vals
-
lr
*
(
mixed_momentum
/
adaptive_term
)
p_vals
=
p_vals
-
lr
*
(
mixed_momentum
/
adaptive_term
)
tl
.
store
(
p_ptr
+
offsets
,
p_vals
,
mask
=
mask
)
tl
.
store
(
p_ptr
+
offsets
,
p_vals
,
mask
=
mask
)
tl
.
store
(
state1_ptr
+
offsets
,
s1_vals
,
mask
=
mask
)
tl
.
store
(
state1_ptr
+
offsets
,
s1_vals
,
mask
=
mask
)
tl
.
store
(
state2_ptr
+
offsets
,
s2_vals
,
mask
=
mask
)
tl
.
store
(
state2_ptr
+
offsets
,
s2_vals
,
mask
=
mask
)
if
OPTIMIZER_ID
==
5
:
# ADEMAMIX
if
OPTIMIZER_ID
==
5
:
# ADEMAMIX
tl
.
store
(
state1_ptr
+
n_elements
+
offsets
,
s3_vals
,
mask
=
mask
)
tl
.
store
(
state1_ptr
+
n_elements
+
offsets
,
s3_vals
,
mask
=
mask
)
...
@@ -224,7 +226,7 @@ def _optimizer_update_2state_32bit_triton_kernel(
...
@@ -224,7 +226,7 @@ def _optimizer_update_2state_32bit_triton_kernel(
@
triton
.
jit
@
triton
.
jit
def
_optimizer_update_1state_32bit_triton_kernel
(
def
_optimizer_update_1state_32bit_triton_kernel
(
g_ptr
,
g_ptr
,
p_ptr
,
p_ptr
,
state1_ptr
,
state1_ptr
,
state2_ptr
,
state2_ptr
,
unorm_ptr
,
unorm_ptr
,
...
@@ -252,7 +254,7 @@ def _optimizer_update_1state_32bit_triton_kernel(
...
@@ -252,7 +254,7 @@ def _optimizer_update_1state_32bit_triton_kernel(
block_start_idx
=
pid
*
N_PER_TH
block_start_idx
=
pid
*
N_PER_TH
offsets
=
block_start_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
*
N_PER_TH
)
offsets
=
block_start_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
*
N_PER_TH
)
mask
=
offsets
<
n_elements
mask
=
offsets
<
n_elements
g_vals
=
tl
.
load
(
g_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
g_vals
=
tl
.
load
(
g_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
p_vals
=
tl
.
load
(
p_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
p_vals
=
tl
.
load
(
p_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
s1_vals
=
tl
.
load
(
state1_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
s1_vals
=
tl
.
load
(
state1_ptr
+
offsets
,
mask
=
mask
,
other
=
0.0
)
...
@@ -260,19 +262,19 @@ def _optimizer_update_1state_32bit_triton_kernel(
...
@@ -260,19 +262,19 @@ def _optimizer_update_1state_32bit_triton_kernel(
g_vals
=
gnorm_scale
*
g_vals
g_vals
=
gnorm_scale
*
g_vals
if
weight_decay
>
0.0
:
if
weight_decay
>
0.0
:
g_vals
=
g_vals
+
p_vals
*
weight_decay
g_vals
=
g_vals
+
p_vals
*
weight_decay
update_scale
=
1.0
update_scale
=
1.0
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
current_unorm
=
tl
.
sqrt
(
tl
.
load
(
unorm_ptr
))
current_unorm
=
tl
.
sqrt
(
tl
.
load
(
unorm_ptr
))
if
current_unorm
>
max_unorm
*
param_norm
+
eps
:
if
current_unorm
>
max_unorm
*
param_norm
+
eps
:
update_scale
=
(
max_unorm
*
param_norm
+
eps
)
/
current_unorm
update_scale
=
(
max_unorm
*
param_norm
+
eps
)
/
current_unorm
if
OPTIMIZER_ID
==
0
:
# MOMENTUM
if
OPTIMIZER_ID
==
0
:
# MOMENTUM
if
step
==
1
:
if
step
==
1
:
s1_vals
=
g_vals
s1_vals
=
g_vals
else
:
else
:
s1_vals
=
s1_vals
*
beta1
+
g_vals
s1_vals
=
s1_vals
*
beta1
+
g_vals
update_val
=
update_scale
*
(
-
lr
*
s1_vals
)
update_val
=
update_scale
*
(
-
lr
*
s1_vals
)
p_vals
=
p_vals
+
update_val
p_vals
=
p_vals
+
update_val
...
@@ -280,21 +282,21 @@ def _optimizer_update_1state_32bit_triton_kernel(
...
@@ -280,21 +282,21 @@ def _optimizer_update_1state_32bit_triton_kernel(
momentum_update
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
momentum_update
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
update_val
=
update_scale
*
lr
*
tl
.
where
(
momentum_update
>
0
,
1.0
,
tl
.
where
(
momentum_update
<
0
,
-
1.0
,
0.0
))
update_val
=
update_scale
*
lr
*
tl
.
where
(
momentum_update
>
0
,
1.0
,
tl
.
where
(
momentum_update
<
0
,
-
1.0
,
0.0
))
p_vals
=
p_vals
-
update_val
p_vals
=
p_vals
-
update_val
s1_vals
=
s1_vals
*
beta2
+
(
1.0
-
beta2
)
*
g_vals
s1_vals
=
s1_vals
*
beta2
+
(
1.0
-
beta2
)
*
g_vals
elif
OPTIMIZER_ID
==
1
:
# RMSPROP
elif
OPTIMIZER_ID
==
1
:
# RMSPROP
s1_vals
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
*
g_vals
s1_vals
=
s1_vals
*
beta1
+
(
1.0
-
beta1
)
*
g_vals
*
g_vals
update_val
=
update_scale
*
lr
*
g_vals
/
(
tl
.
sqrt
(
s1_vals
)
+
eps
)
update_val
=
update_scale
*
lr
*
g_vals
/
(
tl
.
sqrt
(
s1_vals
)
+
eps
)
p_vals
=
p_vals
-
update_val
p_vals
=
p_vals
-
update_val
elif
OPTIMIZER_ID
==
2
:
# ADAGRAD
elif
OPTIMIZER_ID
==
2
:
# ADAGRAD
s1_vals
=
s1_vals
+
g_vals
*
g_vals
s1_vals
=
s1_vals
+
g_vals
*
g_vals
update_val
=
lr
*
g_vals
/
(
tl
.
sqrt
(
s1_vals
)
+
eps
)
update_val
=
lr
*
g_vals
/
(
tl
.
sqrt
(
s1_vals
)
+
eps
)
p_vals
=
p_vals
-
update_val
p_vals
=
p_vals
-
update_val
tl
.
store
(
p_ptr
+
offsets
,
p_vals
,
mask
=
mask
)
tl
.
store
(
p_ptr
+
offsets
,
p_vals
,
mask
=
mask
)
tl
.
store
(
state1_ptr
+
offsets
,
s1_vals
,
mask
=
mask
)
tl
.
store
(
state1_ptr
+
offsets
,
s1_vals
,
mask
=
mask
)
...
@@ -367,34 +369,104 @@ def optimizer_update_32bit_impl(
...
@@ -367,34 +369,104 @@ def optimizer_update_32bit_impl(
if
optimizer_name
==
"lion"
:
if
optimizer_name
==
"lion"
:
fn_update
[
grid
](
fn_update
[
grid
](
g
,
p
,
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
g
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
p
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
skip_zeros
,
state1
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
skip_zeros
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
)
)
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
unorm_vec
.
zero_
()
unorm_vec
.
zero_
()
fn_preprocess
[
grid
](
fn_preprocess
[
grid
](
g
,
p
,
state1
,
state2
,
unorm_vec
,
g
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
p
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
state1
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
state2
,
unorm_vec
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
)
)
else
:
else
:
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
unorm_vec
.
zero_
()
unorm_vec
.
zero_
()
fn_preprocess
[
grid
](
fn_preprocess
[
grid
](
g
,
p
,
state1
,
state2
,
unorm_vec
,
g
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
p
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
state1
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
state2
,
unorm_vec
,
beta1
,
beta2
,
eps
,
weight_decay
,
step
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
)
)
fn_update
[
grid
](
fn_update
[
grid
](
g
,
p
,
state1
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
g
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
p
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
skip_zeros
,
state1
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
state2
,
unorm_vec
,
max_unorm
,
param_norm
,
beta1
,
beta2
,
beta3
,
alpha
,
eps
,
weight_decay
,
step
,
beta1_step
,
beta2_step
,
lr
,
gnorm_scale
,
skip_zeros
,
p
.
numel
(),
optimizer_id
,
BLOCK_SIZE
,
N_PER_TH
,
num_warps
=
2
,
)
)
bitsandbytes/backends/triton/ops.py
View file @
4b025748
...
@@ -3,7 +3,7 @@ from typing import Optional
...
@@ -3,7 +3,7 @@ from typing import Optional
import
torch
import
torch
from
.
import
triton_kernels
,
kernels
_optim
from
.
import
kernels_optim
,
triton_
kernels
# currently codes unused, kept for reference
# currently codes unused, kept for reference
# Should be the same for quant/dequant
# Should be the same for quant/dequant
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment