OpenDAS / bitsandbytes · Commit aca9778e (Unverified)

Authored Jun 24, 2025 by Aman Gupta; committed by GitHub, Jun 24, 2025.

Make minor improvements to optimizer.py (#1687)

Parent: fd2949ab

1 changed file (bitsandbytes/optim/optimizer.py) with 6 additions and 8 deletions (+6, −8).
@@ -64,9 +64,9 @@ class GlobalOptimManager:
         parameters (`torch.Tensor` or `list(torch.Tensors)`):
             The input parameters.
         key (`str`):
-            The hyperparamter to override.
+            The hyperparameter to override.
         value:
-            The hyperparameter values.
+            The hyperparameter value.
         key_value_dict (`dict`):
             A dictionary with multiple key-values to override.
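The docstring above describes the contract of `override_config`: it accepts either a single `key`/`value` pair or a `key_value_dict`, and records per-parameter hyperparameter overrides. A minimal stdlib sketch of that contract (an illustration only, not the bitsandbytes implementation; `MiniOptimManager` and its attribute names are hypothetical):

```python
# Sketch of the override_config contract documented above: a manager maps
# parameter objects to per-parameter hyperparameter overrides, accepting
# either a single key/value pair or a dict of overrides.

class MiniOptimManager:
    def __init__(self):
        # keyed by id() so parameter objects need not be hashable by value
        self.index2config = {}

    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
        if not isinstance(parameters, list):
            parameters = [parameters]
        if key is not None and value is not None:
            # normalize the single-key form into the dict form
            key_value_dict = {key: value}
        if key_value_dict is not None:
            for p in parameters:
                self.index2config.setdefault(id(p), {}).update(key_value_dict)

manager = MiniOptimManager()
p = object()  # stands in for a torch.Tensor parameter
manager.override_config(p, "optim_bits", 32)
manager.override_config([p], key_value_dict={"percentile_clipping": 95})
print(manager.index2config[id(p)])  # {'optim_bits': 32, 'percentile_clipping': 95}
```

Both call forms accumulate into the same per-parameter dict, which is why a later `key_value_dict` call can extend an earlier single-key override.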
@@ -115,7 +115,7 @@ class Optimizer8bit(torch.optim.Optimizer):
     Base 8-bit optimizer class.

     Arguments:
-        params (`torch.tensor`):
+        params (`torch.Tensor`):
             The input parameters to optimize.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
@@ -291,7 +291,7 @@ class Optimizer8bit(torch.optim.Optimizer):
                     self.update_step(group, p, gindex, pindex)
                     torch.cuda.synchronize()
         if self.is_paged:
-            # all paged operation are asynchronous, we need
+            # all paged operations are asynchronous, we need
             # to sync to make sure all tensors are in the right state
             torch.cuda.synchronize()
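The hunk above keeps the pattern of synchronizing after asynchronous paged operations: results must not be read until every pending update has landed. A stdlib analogy of that pattern using threads (illustration only; real paged optimizers use `torch.cuda.synchronize()`, and `paged_update` is a hypothetical stand-in):

```python
# Async work followed by an explicit synchronization point, mirroring the
# sync-after-paged-ops pattern in the diff above.
from concurrent.futures import ThreadPoolExecutor

state = {}

def paged_update(i):
    # pretend this is an asynchronous paged optimizer update
    state[i] = i * i

with ThreadPoolExecutor() as pool:
    futures = [pool.submit(paged_update, i) for i in range(4)]
    for f in futures:
        f.result()  # the "synchronize": wait until every update has landed

print(sorted(state.items()))  # [(0, 0), (1, 1), (2, 4), (3, 9)]
```

Without the wait, a reader of `state` could observe a partially applied set of updates, which is exactly the hazard the comment in the diff warns about for paged tensors.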
@@ -371,7 +371,7 @@ class Optimizer2State(Optimizer8bit):
     Arguments:
         optimizer_name (`str`):
             The name of the optimizer.
-        params (`torch.tensor`):
+        params (`torch.Tensor`):
             The input parameters to optimize.
         lr (`float`, defaults to 1e-3):
             The learning rate.
@@ -428,7 +428,6 @@ class Optimizer2State(Optimizer8bit):
         if args is None:
             args = {}
             args["optim_bits"] = optim_bits
-            args["percentile_clipping"] = 100
             args["min_8bit_size"] = min_8bit_size
             args["percentile_clipping"] = percentile_clipping
             args["block_wise"] = block_wise
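The deleted line was dead code: `args["percentile_clipping"] = 100` was overwritten two lines later by the assignment from the constructor argument. A standalone sketch of the corrected defaulting logic (the default values shown are assumptions mirroring the constructor signature, not taken from this diff):

```python
# Builds the optimizer's args dict only when the caller did not supply one,
# matching the corrected branch in the diff above.

def build_args(args=None, optim_bits=32, min_8bit_size=4096,
               percentile_clipping=100, block_wise=True):
    if args is None:
        args = {}
        args["optim_bits"] = optim_bits
        args["min_8bit_size"] = min_8bit_size
        args["percentile_clipping"] = percentile_clipping
        args["block_wise"] = block_wise
    return args

print(build_args())
# {'optim_bits': 32, 'min_8bit_size': 4096, 'percentile_clipping': 100, 'block_wise': True}
```

Because `percentile_clipping` already defaults through the parameter, hard-coding `100` first had no observable effect; removing it changes nothing in behavior while making the branch easier to read. The same cleanup is applied to `Optimizer1State` in the later hunk.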
@@ -613,7 +612,7 @@ class Optimizer1State(Optimizer8bit):
     Arguments:
         optimizer_name (`str`):
             The name of the optimizer.
-        params (`torch.tensor`):
+        params (`torch.Tensor`):
             The input parameters to optimize.
         lr (`float`, defaults to 1e-3):
             The learning rate.
@@ -655,7 +654,6 @@ class Optimizer1State(Optimizer8bit):
         if args is None:
             args = {}
             args["optim_bits"] = optim_bits
-            args["percentile_clipping"] = 100
             args["min_8bit_size"] = min_8bit_size
             args["percentile_clipping"] = percentile_clipping
             args["block_wise"] = block_wise