Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
2eb31083
Commit
2eb31083
authored
Apr 11, 2023
by
Tim Dettmers
Browse files
Fixed bug where beta2 was not passed into Lion 32-bit.
parent
792af5c8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
11 deletions
+15
-11
bitsandbytes/optim/optimizer.py
bitsandbytes/optim/optimizer.py
+1
-1
tests/test_optim.py
tests/test_optim.py
+14
-10
No files found.
bitsandbytes/optim/optimizer.py
View file @
2eb31083
...
...
@@ -665,7 +665,7 @@ class Optimizer1State(Optimizer8bit):
step
,
config
[
"lr"
],
None
,
0.0
,
config
[
'betas'
][
1
]
,
config
[
"weight_decay"
],
gnorm_scale
,
state
[
"unorm_vec"
]
if
config
[
"max_unorm"
]
>
0.0
else
None
,
...
...
tests/test_optim.py
View file @
2eb31083
...
...
@@ -22,7 +22,7 @@ def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
idx
=
torch
.
isclose
(
a
,
b
,
rtol
,
atol
)
error_count
=
(
idx
==
0
).
sum
().
item
()
if
error_count
>
max_error_count
:
print
(
f
"Too many values not close: assert
{
sumval
}
<
{
count
}
"
)
print
(
f
"Too many values not close: assert
{
error_count
}
<
{
max_error_
count
}
"
)
torch
.
testing
.
assert_allclose
(
a
,
b
,
rtol
,
atol
)
...
...
@@ -170,6 +170,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer
.
step
()
torch_optimizer
.
step
()
for
name1
,
name2
in
str2statenames
[
optim_name
]:
torch
.
testing
.
assert_allclose
(
torch_optimizer
.
state
[
p1
][
name1
],
...
...
@@ -178,7 +179,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
rtol
=
rtol
,
)
torch
.
testing
.
assert_allclose
(
p1
,
p2
.
float
(),
atol
=
atol
,
rtol
=
rtol
)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close
(
p1
,
p2
.
float
(),
atol
,
rtol
,
max_error_count
=
10
)
if
i
%
(
k
//
5
)
==
0
and
i
>
0
:
path
=
get_temp_dir
()
...
...
@@ -188,14 +191,15 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer
=
str2optimizers
[
optim_name
][
1
]([
p2
])
bnb_optimizer
.
load_state_dict
(
torch
.
load
(
join
(
path
,
"opt.pt"
)))
rm_path
(
path
)
torch
.
testing
.
assert_allclose
(
p1
,
p2
.
float
(),
atol
=
atol
,
rtol
=
rtol
)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close
(
p1
,
p2
.
float
(),
atol
,
rtol
,
max_error_count
=
10
)
for
name1
,
name2
in
str2statenames
[
optim_name
]:
torch
.
testing
.
assert_allclose
(
torch_optimizer
.
state
[
p1
][
name1
],
bnb_optimizer
.
state
[
p2
][
name2
],
atol
=
atol
,
rtol
=
rtol
,
)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close
(
torch_optimizer
.
state
[
p1
][
name1
],
bnb_optimizer
.
state
[
p2
][
name2
],
atol
=
atol
,
rtol
=
rtol
,
max_error_count
=
10
)
if
gtype
==
torch
.
float16
:
# the adam buffers should also be close because they are 32-bit
...
...
@@ -343,7 +347,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
dequant_states
.
append
(
s1
.
clone
())
err
=
torch
.
abs
(
p1
-
p2
)
relerr
=
err
/
torch
.
abs
(
p1
)
relerr
=
err
/
(
torch
.
abs
(
p1
)
+
1e-9
)
assert
err
.
mean
()
<
0.0001
assert
relerr
.
mean
()
<
0.001
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment