Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
1f2ca43a
Unverified
Commit
1f2ca43a
authored
May 30, 2024
by
Titus
Committed by
GitHub
May 30, 2024
Browse files
Merge pull request #1222 from EtienneDosSantos/main
Add erroneously missing optimizers to `str2optimizer32bit`
parents
d9b1125c
7a338db2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
89 additions
and
33 deletions
+89
-33
bitsandbytes/functional.py
bitsandbytes/functional.py
+89
-33
No files found.
bitsandbytes/functional.py
View file @
1f2ca43a
...
...
@@ -27,79 +27,135 @@ name2qmap = {}
if
lib
and
lib
.
compiled_with_cuda
:
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit
=
{
"adagrad"
:
(
lib
.
cadagrad32bit_grad_fp32
,
lib
.
cadagrad32bit_grad_fp16
,
),
"adam"
:
(
lib
.
cadam32bit_grad_fp32
,
lib
.
cadam32bit_grad_fp16
,
lib
.
cadam32bit_grad_bf16
,
),
"momentum"
:
(
lib
.
cmomentum32bit_grad_32
,
lib
.
cmomentum32bit_grad_16
,
"pagedadam"
:
(
lib
.
cpagedadam32bit_grad_fp32
,
lib
.
cpagedadam32bit_grad_fp16
,
lib
.
cpagedadam32bit_grad_bf16
,
),
"rmsprop"
:
(
lib
.
crmsprop32bit_grad_32
,
lib
.
crmsprop32bit_grad_16
,
"adamw"
:
(
lib
.
cadam32bit_grad_fp32
,
lib
.
cadam32bit_grad_fp16
,
lib
.
cadam32bit_grad_bf16
,
),
"pagedadamw"
:
(
lib
.
cpagedadam32bit_grad_fp32
,
lib
.
cpagedadam32bit_grad_fp16
,
lib
.
cpagedadam32bit_grad_bf16
,
),
"lamb"
:
(
lib
.
cadam32bit_grad_fp32
,
lib
.
cadam32bit_grad_fp16
,
),
"lars"
:
(
lib
.
clars32bit_grad_fp32
,
lib
.
clars32bit_grad_fp16
,
),
"lion"
:
(
lib
.
clion32bit_grad_fp32
,
lib
.
clion32bit_grad_fp16
,
lib
.
clion32bit_grad_bf16
,
),
"adagrad"
:
(
lib
.
cadagrad32bit_grad_32
,
lib
.
cadagrad32bit_grad_16
,
"momentum"
:
(
lib
.
cmomentum32bit_grad_fp32
,
lib
.
cmomentum32bit_grad_fp16
,
),
"rmsprop"
:
(
lib
.
crmsprop32bit_grad_fp32
,
lib
.
crmsprop32bit_grad_fp16
,
),
}
str2optimizer8bit
=
{
"adagrad"
:
(
lib
.
cadagrad8bit_grad_fp32
,
lib
.
cadagrad8bit_grad_fp16
,
),
"adam"
:
(
lib
.
cadam_static_8bit_grad_32
,
lib
.
cadam_static_8bit_grad_16
,
lib
.
cadam_static_8bit_grad_
fp
32
,
lib
.
cadam_static_8bit_grad_
fp
16
,
),
"momentum"
:
(
lib
.
cmomentum_static_8bit_grad_32
,
lib
.
cmomentum_static_8bit_grad_16
,
"pagedadam"
:
(
lib
.
cpagedadam8bit_grad_fp32
,
lib
.
cpagedadam8bit_grad_fp16
,
lib
.
cpagedadam8bit_grad_bf16
,
),
"
rmsprop
"
:
(
lib
.
c
rmsprop
_static_8bit_grad_32
,
lib
.
c
rmsprop
_static_8bit_grad_16
,
"
adamw
"
:
(
lib
.
c
adam
_static_8bit_grad_
fp
32
,
lib
.
c
adam
_static_8bit_grad_
fp
16
,
),
"lion"
:
(
lib
.
clion_static_8bit_grad_32
,
lib
.
clion_static_8bit_grad_16
,
"pagedadamw"
:
(
lib
.
cpagedadam8bit_grad_fp32
,
lib
.
cpagedadam8bit_grad_fp16
,
lib
.
cpagedadam8bit_grad_bf16
,
),
"lamb"
:
(
lib
.
cadam_static_8bit_grad_32
,
lib
.
cadam_static_8bit_grad_16
,
lib
.
cadam_static_8bit_grad_
fp
32
,
lib
.
cadam_static_8bit_grad_
fp
16
,
),
"lars"
:
(
lib
.
cmomentum_static_8bit_grad_32
,
lib
.
cmomentum_static_8bit_grad_16
,
lib
.
clars8bit_grad_fp32
,
lib
.
clars8bit_grad_fp16
,
),
"lion"
:
(
lib
.
clion_static_8bit_grad_fp32
,
lib
.
clion_static_8bit_grad_fp16
,
),
"momentum"
:
(
lib
.
cmomentum_static_8bit_grad_fp32
,
lib
.
cmomentum_static_8bit_grad_fp16
,
),
"rmsprop"
:
(
lib
.
crmsprop_static_8bit_grad_fp32
,
lib
.
crmsprop_static_8bit_grad_fp16
,
),
}
str2optimizer8bit_blockwise
=
{
"adagrad"
:
(
lib
.
cadagrad_8bit_blockwise_grad_fp32
,
lib
.
cadagrad_8bit_blockwise_grad_fp16
,
),
"adam"
:
(
lib
.
cadam_8bit_blockwise_grad_fp32
,
lib
.
cadam_8bit_blockwise_grad_fp16
,
lib
.
cadam_8bit_blockwise_grad_bf16
,
),
"momentum"
:
(
lib
.
cmomentum_8bit_blockwise_grad_fp32
,
lib
.
cmomentum_8bit_blockwise_grad_fp16
,
"pagedadam"
:
(
lib
.
cpagedadam8bit_blockwise_fp32
,
lib
.
cpagedadam8bit_blockwise_fp16
,
lib
.
cpagedadam8bit_blockwise_bf16
,
),
"rmsprop"
:
(
lib
.
crmsprop_8bit_blockwise_grad_fp32
,
lib
.
crmsprop_8bit_blockwise_grad_fp16
,
"adamw"
:
(
lib
.
cadam_8bit_blockwise_grad_fp32
,
lib
.
cadam_8bit_blockwise_grad_fp16
,
lib
.
cadam_8bit_blockwise_grad_bf16
,
),
"pagedadamw"
:
(
lib
.
cpagedadam8bit_blockwise_fp32
,
lib
.
cpagedadam8bit_blockwise_fp16
,
lib
.
cpagedadam8bit_blockwise_bf16
,
),
"lion"
:
(
lib
.
clion_8bit_blockwise_grad_fp32
,
lib
.
clion_8bit_blockwise_grad_fp16
,
lib
.
clion_8bit_blockwise_grad_bf16
,
),
"adagrad"
:
(
lib
.
cadagrad_8bit_blockwise_grad_fp32
,
lib
.
cadagrad_8bit_blockwise_grad_fp16
,
"momentum"
:
(
lib
.
cmomentum_8bit_blockwise_grad_fp32
,
lib
.
cmomentum_8bit_blockwise_grad_fp16
,
),
"rmsprop"
:
(
lib
.
crmsprop_8bit_blockwise_grad_fp32
,
lib
.
crmsprop_8bit_blockwise_grad_fp16
,
),
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment