Unverified Commit 61db0859 authored by वेदांत, committed by GitHub

doc fix signature for 8-bit optim (#1660)

* doc fix signature for 8-bit optim

* required changes

* precommit
parent df73d3e1
@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
@@ -100,8 +99,10 @@ class Adam8bit(Optimizer2State):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
Note: This parameter is not supported in Adam8bit and must be False.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
Note: This parameter is not used in Adam8bit as it always uses 8-bit optimization.
args (`object`, defaults to `None`):
An object with additional arguments.
min_8bit_size (`int`, defaults to 4096):
@@ -113,6 +114,15 @@ class Adam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
# Validate unsupported parameters
if amsgrad:
    raise ValueError("Adam8bit does not support amsgrad=True")
if optim_bits != 32:
    # We allow the default value of 32 to maintain compatibility with the function signature,
    # but any other value is invalid since Adam8bit always uses 8-bit optimization
    raise ValueError("Adam8bit only supports optim_bits=32 (default value for compatibility)")
super().__init__(
"adam",
params,
@@ -120,7 +130,7 @@ class Adam8bit(Optimizer2State):
betas,
eps,
weight_decay,
8,
8, # Hardcoded to 8 bits
args,
min_8bit_size,
percentile_clipping,
@@ -283,8 +293,10 @@ class PagedAdam8bit(Optimizer2State):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
Note: This parameter is not supported in PagedAdam8bit and must be False.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
Note: This parameter is not used in PagedAdam8bit as it always uses 8-bit optimization.
args (`object`, defaults to `None`):
An object with additional arguments.
min_8bit_size (`int`, defaults to 4096):
@@ -296,6 +308,15 @@ class PagedAdam8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
# Validate unsupported parameters
if amsgrad:
    raise ValueError("PagedAdam8bit does not support amsgrad=True")
if optim_bits != 32:
    # We allow the default value of 32 to maintain compatibility with the function signature,
    # but any other value is invalid since PagedAdam8bit always uses 8-bit optimization
    raise ValueError("PagedAdam8bit only supports optim_bits=32 (default value for compatibility)")
super().__init__(
"adam",
params,
@@ -303,7 +324,7 @@ class PagedAdam8bit(Optimizer2State):
betas,
eps,
weight_decay,
8,
8, # Hardcoded to 8 bits
args,
min_8bit_size,
percentile_clipping,
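For context, the added checks change how the two Adam-based 8-bit optimizers respond to unsupported arguments. The following is a minimal usage sketch, not part of the commit, assuming a bitsandbytes build that includes this change and PyTorch available as torch:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096)

# Default arguments still work; the optimizer state is always quantized to 8 bits.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# amsgrad=True is now rejected explicitly by the validation added above.
try:
    bnb.optim.Adam8bit(model.parameters(), lr=1e-3, amsgrad=True)
except ValueError as err:
    print(err)  # Adam8bit does not support amsgrad=True

# Any optim_bits value other than the default of 32 is rejected as well,
# since the bit width is hardcoded to 8 in the call to super().__init__.
try:
    bnb.optim.PagedAdam8bit(model.parameters(), lr=1e-3, optim_bits=8)
except ValueError as err:
    print(err)  # PagedAdam8bit only supports optim_bits=32 (default value for compatibility)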
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
@@ -98,8 +99,10 @@ class AdamW8bit(Optimizer2State):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
Note: This parameter is not supported in AdamW8bit and must be False.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
Note: This parameter is not used in AdamW8bit as it always uses 8-bit optimization.
args (`object`, defaults to `None`):
An object with additional arguments.
min_8bit_size (`int`, defaults to 4096):
@@ -111,6 +114,15 @@ class AdamW8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
# Validate unsupported parameters
if amsgrad:
    raise ValueError("AdamW8bit does not support amsgrad=True")
if optim_bits != 32:
    # We allow the default value of 32 to maintain compatibility with the function signature,
    # but any other value is invalid since AdamW8bit always uses 8-bit optimization
    raise ValueError("AdamW8bit only supports optim_bits=32 (default value for compatibility)")
super().__init__(
"adam",
params,
@@ -118,7 +130,7 @@ class AdamW8bit(Optimizer2State):
betas,
eps,
weight_decay,
8,
8, # Hardcoded to 8 bits
args,
min_8bit_size,
percentile_clipping,
@@ -279,8 +291,10 @@ class PagedAdamW8bit(Optimizer2State):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
Note: This parameter is not supported in PagedAdamW8bit and must be False.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
Note: This parameter is not used in PagedAdamW8bit as it always uses 8-bit optimization.
args (`object`, defaults to `None`):
An object with additional arguments.
min_8bit_size (`int`, defaults to 4096):
@@ -292,6 +306,15 @@ class PagedAdamW8bit(Optimizer2State):
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
# Validate unsupported parameters
if amsgrad:
    raise ValueError("PagedAdamW8bit does not support amsgrad=True")
if optim_bits != 32:
    # We allow the default value of 32 to maintain compatibility with the function signature,
    # but any other value is invalid since PagedAdamW8bit always uses 8-bit optimization
    raise ValueError("PagedAdamW8bit only supports optim_bits=32 (default value for compatibility)")
super().__init__(
"adam",
params,
@@ -299,7 +322,7 @@ class PagedAdamW8bit(Optimizer2State):
betas,
eps,
weight_decay,
8,
8, # Hardcoded to 8 bits
args,
min_8bit_size,
percentile_clipping,
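The AdamW hunks mirror the Adam ones, so the AdamW variants behave the same way. A short sketch under the same assumptions as above:

import torch
import bitsandbytes as bnb

params = list(torch.nn.Linear(4096, 4096).parameters())

# The paged AdamW variant keeps the same signature; its state is still stored in 8 bits.
optimizer = bnb.optim.PagedAdamW8bit(params, lr=1e-3, weight_decay=1e-2)

# As with the Adam classes, unsupported arguments raise a ValueError at construction time.
try:
    bnb.optim.AdamW8bit(params, amsgrad=True)
except ValueError as err:
    print(err)  # AdamW8bit does not support amsgrad=True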