Unverified Commit ac5d6ee6 authored by Steven Liu's avatar Steven Liu Committed by GitHub
Browse files

[docs] implement API docs (#1075)



* optims

* fix path

* fix path

* mdx

* fix path

* toctree

* fix

* optimizer, adagrad

* add init

* add

* more apis

* params

* clarify

* run pre-commit hooks

---------
Co-authored-by: default avatarTitus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent 87e029bc
......@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848
# Remove f-prefix from strings that don't use formatting
7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6
# format tests/linear_4bit.py
34735ba89de8235ea9da6ef409f814dcea9e2038
......@@ -21,16 +21,7 @@ T = TypeVar("T", bound="torch.nn.Module")
class StableEmbedding(torch.nn.Embedding):
"""
Custom embedding layer designed for stable training in NLP tasks. The stable
embedding layer improves stability during optimization for models with word
embeddings, addressing issues related to the non-uniform distribution of input
tokens.
This stable embedding layer is initialized with Xavier uniform initialization,
followed by layer normalization. It is designed to support aggressive quantization,
addressing extreme gradient variations in non-uniform input distributions. The
stability of training is enhanced by using 32-bit optimizer states specifically
for this layer.
Custom embedding layer designed to improve stability during training for NLP tasks by using 32-bit optimizer states. It is designed to reduce gradient variations that can result from quantization. This embedding layer is initialized with Xavier uniform initialization followed by layer normalization.
Example:
......@@ -47,14 +38,11 @@ class StableEmbedding(torch.nn.Embedding):
```
Attributes:
norm (torch.nn.LayerNorm): Layer normalization applied after the embedding.
norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding.
Methods:
reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
Reference:
- [8-bit optimizer paper](https://arxiv.org/pdf/2110.02861.pdf)
"""
def __init__(
self,
......@@ -71,14 +59,22 @@ class StableEmbedding(torch.nn.Embedding):
) -> None:
"""
Args:
num_embeddings (`int`): The number of unique embeddings (vocabulary size).
embedding_dim (`int`): The dimensionality of the embedding.
padding_idx (`Optional[int]`): If specified, pads the output with zeros at the given index.
max_norm (`Optional[float]`): If given, renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`): The p-norm to compute for the max_norm option.
scale_grad_by_freq (`bool`): Scale gradient by frequency during backpropagation.
sparse (`bool`): If True, computes sparse gradients; False, computes dense gradients.
_weight (`Optional[Tensor]`): Pre-trained embeddings.
num_embeddings (`int`):
The number of unique embeddings (vocabulary size).
embedding_dim (`int`):
The dimensionality of the embedding.
padding_idx (`Optional[int]`):
Pads the output with zeros at the given index.
max_norm (`Optional[float]`):
Renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`):
The p-norm to compute for the `max_norm` option.
scale_grad_by_freq (`bool`, defaults to `False`):
Scale gradient by frequency during backpropagation.
sparse (`bool`, defaults to `False`):
Computes dense gradients. Set to `True` to compute sparse gradients instead.
_weight (`Optional[Tensor]`):
Pretrained embeddings.
"""
super().__init__(
num_embeddings,
......@@ -131,6 +127,9 @@ class StableEmbedding(torch.nn.Embedding):
class Embedding(torch.nn.Embedding):
"""
Embedding class to store and retrieve word embeddings from their indices.
"""
def __init__(
self,
num_embeddings: int,
......@@ -143,6 +142,25 @@ class Embedding(torch.nn.Embedding):
_weight: Optional[Tensor] = None,
device: Optional[device] = None,
) -> None:
"""
Args:
num_embeddings (`int`):
The number of unique embeddings (vocabulary size).
embedding_dim (`int`):
The dimensionality of the embedding.
padding_idx (`Optional[int]`):
Pads the output with zeros at the given index.
max_norm (`Optional[float]`):
Renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`):
The p-norm to compute for the `max_norm` option.
scale_grad_by_freq (`bool`, defaults to `False`):
Scale gradient by frequency during backpropagation.
sparse (`bool`, defaults to `False`):
Computes dense gradients. Set to `True` to compute sparse gradients instead.
_weight (`Optional[Tensor]`):
Pretrained embeddings.
"""
super().__init__(
num_embeddings,
embedding_dim,
......@@ -416,7 +434,19 @@ class Linear4bit(nn.Linear):
class LinearFP4(Linear4bit):
"""
Implements the FP4 data type.
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
"""
Args:
input_features (`str`):
Number of input features of the linear layer.
output_features (`str`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
......@@ -432,6 +462,15 @@ class LinearNF4(Linear4bit):
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
"""
Args:
input_features (`str`):
Number of input features of the linear layer.
output_features (`str`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
......
......@@ -20,6 +20,33 @@ class Adagrad(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
Base Adagrad optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
lr_decay (`int`, defaults to 0):
The learning rate decay.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
initial_accumulator_value (`int`, defaults to 0):
The initial momemtum values.
eps (`float`, defaults to 1e-10):
The epsilon value prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
......@@ -62,6 +89,33 @@ class Adagrad8bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
8-bit Adagrad optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
lr_decay (`int`, defaults to 0):
The learning rate decay.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
initial_accumulator_value (`int`, defaults to 0):
The initial momemtum values.
eps (`float`, defaults to 1e-10):
The epsilon value prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 8):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
......@@ -105,6 +159,33 @@ class Adagrad32bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
32-bit Adagrad optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
lr_decay (`int`, defaults to 0):
The learning rate decay.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
initial_accumulator_value (`int`, defaults to 0):
The initial momemtum values.
eps (`float`, defaults to 1e-10):
The epsilon value prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
......
......@@ -16,31 +16,205 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class Adam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Base Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Adam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Adam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
32-bit Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class PagedAdam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Paged Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit paged Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Paged 32-bit Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class AnalysisAdam(torch.optim.Optimizer):
......
......@@ -8,30 +8,204 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Base AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
class AdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
class AdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
32-bit AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class PagedAdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged 8-bit AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged 32-bit AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
......@@ -23,6 +23,39 @@ class LAMB(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
"""
Base LAMB optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
bias_correction (`bool`, defaults to `True`):
Whether to apply bias correction to the first and second-order moments.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
adam_w_mode (`bool`, defaults to `True`):
Whether to use the AdamW variant.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 1.0):
The maximum gradient norm.
"""
super().__init__(
"lamb",
params,
......@@ -56,6 +89,37 @@ class LAMB8bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
"""
8-bit LAMB optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
bias_correction (`bool`, defaults to `True`):
Whether to apply bias correction to the first and second-order moments.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
adam_w_mode (`bool`, defaults to `True`):
Whether to use the AdamW variant.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 1.0):
The maximum gradient norm.
"""
super().__init__(
"lamb",
params,
......@@ -89,6 +153,37 @@ class LAMB32bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
"""
32-bit LAMB optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
bias_correction (`bool`, defaults to `True`):
Whether to apply bias correction to the first and second-order moments.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
adam_w_mode (`bool`, defaults to `True`):
Whether to use the AdamW variant.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 1.0):
The maximum gradient norm.
"""
super().__init__(
"lamb",
params,
......
......@@ -23,6 +23,33 @@ class LARS(Optimizer1State):
percentile_clipping=100,
max_unorm=0.02,
):
"""
Base LARS optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
max_unorm (`float`, defaults to 0.02):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
......@@ -57,6 +84,31 @@ class LARS8bit(Optimizer1State):
percentile_clipping=100,
max_unorm=0.02,
):
"""
8-bit LARS optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
max_unorm (`float`, defaults to 0.02):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
......@@ -91,6 +143,31 @@ class LARS32bit(Optimizer1State):
percentile_clipping=100,
max_unorm=0.02,
):
"""
32-bit LARS optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
max_unorm (`float`, defaults to 0.02):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
......
......@@ -7,25 +7,165 @@ from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Base Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Lion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Lion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
32-bit Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class PagedLion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedLion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged 8-bit Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedLion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged 32-bit Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
......@@ -18,6 +18,9 @@ class MockArgs:
class GlobalOptimManager:
"""
A global optimizer manager for enabling custom optimizer configs.
"""
_instance = None
def __init__(self):
......@@ -53,22 +56,40 @@ class GlobalOptimManager:
self, parameters, key=None, value=None, key_value_dict=None
):
"""
Overrides initial optimizer config for specific parameters.
Override initial optimizer config with specific hyperparameters.
The key-values of the optimizer config for the input parameters are overridden
This can be both, optimizer parameters like "betas", or "lr" or it can be
8-bit specific parameters like "optim_bits", "percentile_clipping".
This can be both, optimizer parameters like `betas` or `lr`, or it can be
8-bit specific parameters like `optim_bits` or `percentile_clipping`.
Parameters
----------
parameters : torch.Tensor or list(torch.Tensors)
Arguments:
parameters (`torch.Tensor` or `list(torch.Tensors)`):
The input parameters.
key : str
key (`str`):
The hyperparamter to override.
value : object
The value for the hyperparamters.
key_value_dict : dict
value:
The hyperparameter values.
key_value_dict (`dict`):
A dictionary with multiple key-values to override.
Example:
```py
import torch
import bitsandbytes as bnb
mng = bnb.optim.GlobalOptimManager.get_instance()
model = MyModel()
mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU
model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
# 2. override: the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, 'optim_bits', 32)
```
"""
self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter):
......@@ -92,6 +113,17 @@ class GlobalOptimManager:
class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32, is_paged=False):
"""
Base 8-bit optimizer class.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__(params, defaults)
self.initialized = False
self.name2qmap = {}
......@@ -125,11 +157,11 @@ class Optimizer8bit(torch.optim.Optimizer):
super().__setstate__(state)
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
"""Load an optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
Arguments:
state_dict (`dict`):
An optimizer state (should be returned from a call to `state_dict`) to load.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
......@@ -237,11 +269,11 @@ class Optimizer8bit(torch.optim.Optimizer):
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
"""Perform a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
closure (`Callable`, *optional*, defaults to `None`):
A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
......@@ -339,6 +371,39 @@ class Optimizer2State(Optimizer8bit):
skip_zeros=False,
is_paged=False
):
"""
Base 2-state update optimizer class.
Arguments:
optimizer_name (`str`):
The name of the optimizer.
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple`, defaults to (0.9, 0.999)):
The beta values for the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value for the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 0.0):
The maximum value to normalize each block with.
skip_zeros (`bool`, defaults to `False`):
Whether to skip zero values for sparse gradients and models to ensure correct updates.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
......@@ -552,6 +617,39 @@ class Optimizer1State(Optimizer8bit):
skip_zeros=False,
is_paged=False
):
"""
Base 1-state update optimizer class.
Arguments:
optimizer_name (`str`):
The name of the optimizer.
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple`, defaults to (0.9, 0.0)):
The beta values for the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value for the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 0.0):
The maximum value to normalize each block with.
skip_zeros (`bool`, defaults to `False`):
Whether to skip zero values for sparse gradients and models to ensure correct updates.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
......
......@@ -21,6 +21,35 @@ class RMSprop(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
Base RMSprop optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
alpha (`float`, defaults to 0.99):
The alpha value is the decay rate of the squared gradients of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
centered (`bool`, defaults to `False`):
Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
......@@ -57,6 +86,35 @@ class RMSprop8bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
8-bit RMSprop optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
alpha (`float`, defaults to 0.99):
The alpha value is the decay rate of the squared gradients of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
centered (`bool`, defaults to `False`):
Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
......@@ -93,6 +151,35 @@ class RMSprop32bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
32-bit RMSprop optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
alpha (`float`, defaults to 0.99):
The alpha value is the decay rate of the squared gradients of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
centered (`bool`, defaults to `False`):
Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
......
......@@ -20,6 +20,33 @@ class SGD(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
Base SGD optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if momentum == 0:
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
......@@ -51,6 +78,31 @@ class SGD8bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
8-bit SGD optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if momentum == 0:
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
......@@ -82,6 +134,31 @@ class SGD32bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
32-bit SGD optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if momentum == 0:
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
......
......@@ -26,5 +26,33 @@
title: Papers, resources & how to cite
- title: API reference
sections:
- local: quantization
- local: reference/quantization
title: Quantization
- title: Optimizers
sections:
- local: reference/optim/optim_overview
title: Overview
- local: reference/optim/adagrad
title: AdaGrad
- local: reference/optim/adam
title: Adam
- local: reference/optim/adamw
title: AdamW
- local: reference/optim/lamb
title: LAMB
- local: reference/optim/lars
title: LARS
- local: reference/optim/lion
title: Lion
- local: reference/optim/rmsprop
title: RMSprop
- local: reference/optim/sgd
title: SGD
- title: k-bit quantizers
sections:
- local: reference/nn/linear8bit
title: 8-bit quantizer
- local: reference/nn/linear4bit
title: 4-bit quantizer
- local: reference/nn/embeddings
title: Embedding
# Embedding
The embedding class is used to store and retrieve word embeddings from their indices. There are two types of embeddings in bitsandbytes, the standard PyTorch [`Embedding`] class and the [`StableEmbedding`] class.
The [`StableEmbedding`] class was introduced in the [8-bit Optimizers via Block-wise Quantization](https://hf.co/papers/2110.02861) paper to reduce gradient variance as a result of the non-uniform distribution of input tokens. This class is designed to support quantization.
## Embedding
[[autodoc]] bitsandbytes.nn.Embedding
- __init__
## StableEmbedding
[[autodoc]] bitsandbytes.nn.StableEmbedding
- __init__
# 4-bit quantization
[QLoRA](https://hf.co/papers/2305.14314) is a finetuning method that quantizes a model to 4-bits and adds a set of low-rank adaptation (LoRA) weights to the model and tuning them through the quantized weights. This method also introduces a new data type, 4-bit NormalFloat (`LinearNF4`) in addition to the standard Float4 data type (`LinearFP4`). `LinearNF4` is a quantization data type for normally distributed data and can improve performance.
## Linear4bit
[[autodoc]] bitsandbytes.nn.Linear4bit
- __init__
## LinearFP4
[[autdodoc]] bitsandbytes.nn.LinearFP4
- __init__
## LinearNF4
[[autodoc]] bitsandbytes.nn.LinearNF4
- __init__
## Params4bit
[[autodoc]] bitsandbytes.nn.Params4bit
- __init__
# 8-bit quantization
[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that doesn't degrade performance which makes large model inference more accessible. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are multiplied in 8-bit and quantized to Int8 before being dequantized back to 16-bits. The outputs from the 16-bit and 8-bit multiplication are combined to produce the final output.
## Linear8bitLt
[[autodoc]] bitsandbytes.nn.Linear8bitLt
- __init__
## Int8Params
[[autodoc]] bitsandbytes.nn.Int8Params
- __init__
# AdaGrad
[AdaGrad (Adaptive Gradient)](https://jmlr.org/papers/v12/duchi11a.html) is an adaptive learning rate optimizer. AdaGrad stores a sum of the squared past gradients for each parameter and uses it to scale their learning rate. This allows the learning rate to be automatically lower or higher depending on the magnitude of the gradient, eliminating the need to manually tune the learning rate.
## Adagrad[[api-class]]
[[autodoc]] bitsandbytes.optim.Adagrad
- __init__
## Adagrad8bit
[[autodoc]] bitsandbytes.optim.Adagrad8bit
- __init__
## Adagrad32bit
[[autodoc]] bitsandbytes.optim.Adagrad32bit
- __init__
# Adam
[Adam (Adaptive moment estimation)](https://hf.co/papers/1412.6980) is an adaptive learning rate optimizer, combining ideas from [`SGD`] with momentum and [`RMSprop`] to automatically scale the learning rate:
- a weighted average of the past gradients to provide direction (first-moment)
- a weighted average of the *squared* past gradients to adapt the learning rate to each parameter (second-moment)
bitsandbytes also supports paged optimizers which take advantage of CUDAs unified memory to transfer memory from the GPU to the CPU when GPU memory is exhausted.
## Adam[[api-class]]
[[autodoc]] bitsandbytes.optim.Adam
- __init__
## Adam8bit
[[autodoc]] bitsandbytes.optim.Adam8bit
- __init__
## Adam32bit
[[autodoc]] bitsandbytes.optim.Adam32bit
- __init__
## PagedAdam
[[autodoc]] bitsandbytes.optim.PagedAdam
- __init__
## PagedAdam8bit
[[autodoc]] bitsandbytes.optim.PagedAdam8bit
- __init__
## PagedAdam32bit
[[autodoc]] bitsandbytes.optim.PagedAdam32bit
- __init__
# AdamW
[AdamW](https://hf.co/papers/1711.05101) is a variant of the [`Adam`] optimizer that separates weight decay from the gradient update based on the observation that the weight decay formulation is different when applied to [`SGD`] and [`Adam`].
bitsandbytes also supports paged optimizers which take advantage of CUDAs unified memory to transfer memory from the GPU to the CPU when GPU memory is exhausted.
## AdamW[[api-class]]
[[autodoc]] bitsandbytes.optim.AdamW
- __init__
## AdamW8bit
[[autodoc]] bitsandbytes.optim.AdamW8bit
- __init__
## AdamW32bit
[[autodoc]] bitsandbytes.optim.AdamW32bit
- __init__
## PagedAdamW
[[autodoc]] bitsandbytes.optim.PagedAdamW
- __init__
## PagedAdamW8bit
[[autodoc]] bitsandbytes.optim.PagedAdamW8bit
- __init__
## PagedAdamW32bit
[[autodoc]] bitsandbytes.optim.PagedAdamW32bit
- __init__
# LAMB
[LAMB (Layerwise adaptive large batch optimization)](https://hf.co/papers/1904.00962) is an adaptive optimizer designed for training with large batch sizes to accelerate training, combining ideas from [`LARS`] and [`Adam`] to automatically scale the learning rate for each layer:
- calculates a *trust ratio* between the weight and gradient norm in a layer and clips the ratio to prevent overly large or small updates
- updates weights with the first and second-moments
## LAMB[[api-class]]
[[autodoc]] bitsandbytes.optim.LAMB
- __init__
## LAMB8bit
[[autodoc]] bitsandbytes.optim.LAMB8bit
- __init__
## LAMB32bit
[[autodoc]] bitsandbytes.optim.LAMB32bit
- __init__
# LARS
[LARS (Layer-wise Adaptive Rate Scaling)](https:/hf.co/papers/1708.03888) is an optimizer designed for training with large batch sizes to accelerate training. LARS uses a separate learning rate for each *layer* instead of each parameter. The learning rate is calculated from a *trust ratio* between the weight and gradient norm in a layer. This helps calibrate a stable update size.
## LARS[[api-class]]
[[autodoc]] bitsandbytes.optim.LARS
- __init__
## LARS8bit
[[autodoc]] bitsandbytes.optim.LARS8bit
- __init__
## LARS32bit
[[autodoc]] bitsandbytes.optim.LARS32bit
- __init__
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment