Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
a43cd200
Commit
a43cd200
authored
Mar 22, 2023
by
Phil Wang
Browse files
add some code in test_optim.py, although it seems to be failing
parent
9b656f46
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
1 deletion
+23
-1
requirements.txt
requirements.txt
+1
-0
tests/test_optim.py
tests/test_optim.py
+22
-1
No files found.
requirements.txt
View file @
a43cd200
lion-pytorch
pytest
pytest
tests/test_optim.py
View file @
a43cd200
...
@@ -7,6 +7,8 @@ from itertools import product
...
@@ -7,6 +7,8 @@ from itertools import product
from
os.path
import
join
from
os.path
import
join
import
pytest
import
pytest
from
lion_pytorch
import
Lion
import
torch
import
torch
import
bitsandbytes
as
bnb
import
bitsandbytes
as
bnb
...
@@ -31,6 +33,7 @@ str2optimizers = {}
...
@@ -31,6 +33,7 @@ str2optimizers = {}
str2optimizers
[
"adam_pytorch"
]
=
(
None
,
torch
.
optim
.
Adam
,
bnb
.
optim
.
Adam
)
str2optimizers
[
"adam_pytorch"
]
=
(
None
,
torch
.
optim
.
Adam
,
bnb
.
optim
.
Adam
)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers
[
"lion_pytorch"
]
=
(
None
,
Lion
,
bnb
.
optim
.
Lion
)
str2optimizers
[
"momentum_pytorch"
]
=
(
str2optimizers
[
"momentum_pytorch"
]
=
(
None
,
None
,
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
...
@@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = (
...
@@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = (
)
)
str2optimizers
[
"adam"
]
=
(
torch
.
optim
.
Adam
,
bnb
.
optim
.
Adam
)
str2optimizers
[
"adam"
]
=
(
torch
.
optim
.
Adam
,
bnb
.
optim
.
Adam
)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers
[
"lion"
]
=
(
Lion
,
bnb
.
optim
.
Lion
)
str2optimizers
[
"momentum"
]
=
(
str2optimizers
[
"momentum"
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
lambda
pxx
:
bnb
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
...
@@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = (
...
@@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = (
torch
.
optim
.
Adam
,
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
False
),
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
False
),
)
)
str2optimizers
[
"lion8bit"
]
=
(
Lion
,
lambda
pxx
:
bnb
.
optim
.
Lion8bit
(
pxx
,
block_wise
=
False
),
)
str2optimizers
[
"momentum8bit"
]
=
(
str2optimizers
[
"momentum8bit"
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
...
@@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = (
...
@@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = (
torch
.
optim
.
Adam
,
torch
.
optim
.
Adam
,
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
True
),
lambda
pxx
:
bnb
.
optim
.
Adam8bit
(
pxx
,
block_wise
=
True
),
)
)
str2optimizers
[
"lion8bit_blockwise"
]
=
(
Lion
,
lambda
pxx
:
bnb
.
optim
.
Lion8bit
(
pxx
,
block_wise
=
True
),
)
str2optimizers
[
"momentum8bit_blockwise"
]
=
(
str2optimizers
[
"momentum8bit_blockwise"
]
=
(
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
),
lambda
pxx
:
bnb
.
optim
.
SGD8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
True
),
...
@@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
...
@@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
str2statenames
=
{}
str2statenames
=
{}
str2statenames
[
"adam"
]
=
[(
"exp_avg"
,
"state1"
),
(
"exp_avg_sq"
,
"state2"
)]
str2statenames
[
"adam"
]
=
[(
"exp_avg"
,
"state1"
),
(
"exp_avg_sq"
,
"state2"
)]
str2statenames
[
"lion"
]
=
[(
"exp_avg"
,
"state1"
)]
str2statenames
[
"momentum"
]
=
[(
"momentum_buffer"
,
"state1"
)]
str2statenames
[
"momentum"
]
=
[(
"momentum_buffer"
,
"state1"
)]
str2statenames
[
"lars"
]
=
[(
"momentum_buffer"
,
"state1"
)]
str2statenames
[
"lars"
]
=
[(
"momentum_buffer"
,
"state1"
)]
str2statenames
[
"lamb"
]
=
[(
"exp_avg"
,
"state1"
),
(
"exp_avg_sq"
,
"state2"
)]
str2statenames
[
"lamb"
]
=
[(
"exp_avg"
,
"state1"
),
(
"exp_avg_sq"
,
"state2"
)]
...
@@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [
...
@@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
),
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"max2"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"max2"
),
]
]
str2statenames
[
"lion8bit"
]
=
[
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
)
]
str2statenames
[
"lamb8bit"
]
=
[
str2statenames
[
"lamb8bit"
]
=
[
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
),
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"max1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"max2"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"max2"
),
...
@@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [
...
@@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
),
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"absmax2"
),
(
"exp_avg_sq"
,
"state2"
,
"qmap2"
,
"absmax2"
),
]
]
str2statenames
[
"lion8bit_blockwise"
]
=
[
(
"exp_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
)
]
str2statenames
[
"momentum8bit"
]
=
[
str2statenames
[
"momentum8bit"
]
=
[
(
"momentum_buffer"
,
"state1"
,
"qmap1"
,
"max1"
)
(
"momentum_buffer"
,
"state1"
,
"qmap1"
,
"max1"
)
]
]
...
@@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
...
@@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
dim1
=
[
1024
]
dim1
=
[
1024
]
dim2
=
[
32
,
1024
,
4097
,
1
]
dim2
=
[
32
,
1024
,
4097
,
1
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
optimizer_names
=
[
"adam"
,
"momentum"
,
"rmsprop"
,
"lars"
]
optimizer_names
=
[
"adam"
,
"momentum"
,
"rmsprop"
,
"lars"
,
"lion"
]
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
names
=
[
names
=
[
"dim1_{}_dim2_{}_gtype_{}_optim_{}"
.
format
(
*
vals
)
for
vals
in
values
"dim1_{}_dim2_{}_gtype_{}_optim_{}"
.
format
(
*
vals
)
for
vals
in
values
...
@@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097]
...
@@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
optimizer_names
=
[
optimizer_names
=
[
"adam8bit"
,
"adam8bit"
,
"lion8bit"
,
"momentum8bit"
,
"momentum8bit"
,
"rmsprop8bit"
,
"rmsprop8bit"
,
"adam8bit_blockwise"
,
"adam8bit_blockwise"
,
"lion8bit_blockwise"
,
"lars8bit"
,
"lars8bit"
,
"momentum8bit_blockwise"
,
"momentum8bit_blockwise"
,
"rmsprop8bit_blockwise"
,
"rmsprop8bit_blockwise"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment