Unverified commit ac5d6ee6 authored by Steven Liu, committed by GitHub

[docs] implement API docs (#1075)



* optims

* fix path

* fix path

* mdx

* fix path

* toctree

* fix

* optimizer, adagrad

* add init

* add

* more apis

* params

* clarify

* run pre-commit hooks

---------
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent 87e029bc
@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848
# Remove f-prefix from strings that don't use formatting
7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6
# format tests/linear_4bit.py
34735ba89de8235ea9da6ef409f814dcea9e2038
@@ -21,16 +21,7 @@ T = TypeVar("T", bound="torch.nn.Module")
class StableEmbedding(torch.nn.Embedding):
"""
Custom embedding layer designed to improve stability during NLP training by using 32-bit optimizer states. It reduces the gradient variations that can result from quantization. The layer is initialized with Xavier uniform initialization followed by layer normalization.
Example:
@@ -47,14 +38,11 @@ class StableEmbedding(torch.nn.Embedding):
```
Attributes:
norm (`torch.nn.LayerNorm`): Layer normalization applied after the embedding.
Methods:
reset_parameters(): Reset embedding parameters using Xavier uniform initialization.
forward(input: Tensor) -> Tensor: Forward pass through the stable embedding layer.
"""
def __init__(
self,
@@ -71,14 +59,22 @@ class StableEmbedding(torch.nn.Embedding):
) -> None:
"""
Args:
num_embeddings (`int`):
The number of unique embeddings (vocabulary size).
embedding_dim (`int`):
The dimensionality of the embedding.
padding_idx (`Optional[int]`):
Pads the output with zeros at the given index.
max_norm (`Optional[float]`):
Renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`):
The p-norm to compute for the `max_norm` option.
scale_grad_by_freq (`bool`, defaults to `False`):
Scale gradient by frequency during backpropagation.
sparse (`bool`, defaults to `False`):
Computes dense gradients. Set to `True` to compute sparse gradients instead.
_weight (`Optional[Tensor]`):
Pretrained embeddings.
""" """
super().__init__( super().__init__(
num_embeddings, num_embeddings,
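As a quick illustration of the layer documented above, a minimal usage sketch (not part of the diff; it assumes bitsandbytes is installed and imported as `bnb`, and the sizes are illustrative):

```py
import torch
import bitsandbytes as bnb

# Stable embedding for a 10,000-token vocabulary with 128-dim vectors.
# The layer applies Xavier-uniform init and a LayerNorm after the lookup,
# and is intended to be trained with 32-bit optimizer states (see the docstring above).
emb = bnb.nn.StableEmbedding(num_embeddings=10_000, embedding_dim=128)

token_ids = torch.randint(0, 10_000, (4, 16))  # (batch, sequence) of token indices
vectors = emb(token_ids)                       # -> shape (4, 16, 128)
```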
@@ -131,6 +127,9 @@ class StableEmbedding(torch.nn.Embedding):
class Embedding(torch.nn.Embedding):
"""
Embedding class to store and retrieve word embeddings from their indices.
"""
def __init__(
self,
num_embeddings: int,
@@ -143,6 +142,25 @@ class Embedding(torch.nn.Embedding):
_weight: Optional[Tensor] = None,
device: Optional[device] = None,
) -> None:
"""
Args:
num_embeddings (`int`):
The number of unique embeddings (vocabulary size).
embedding_dim (`int`):
The dimensionality of the embedding.
padding_idx (`Optional[int]`):
Pads the output with zeros at the given index.
max_norm (`Optional[float]`):
Renormalizes embeddings to have a maximum L2 norm.
norm_type (`float`, defaults to `2.0`):
The p-norm to compute for the `max_norm` option.
scale_grad_by_freq (`bool`, defaults to `False`):
Scale gradient by frequency during backpropagation.
sparse (`bool`, defaults to `False`):
Computes dense gradients. Set to `True` to compute sparse gradients instead.
_weight (`Optional[Tensor]`):
Pretrained embeddings.
"""
super().__init__(
num_embeddings,
embedding_dim,
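A similar sketch for the plain `bnb.nn.Embedding` documented here, showing the `padding_idx` behavior (sizes are illustrative; assumes bitsandbytes is importable):

```py
import torch
import bitsandbytes as bnb

# Index 0 is the padding index; its embedding is initialized to zeros.
emb = bnb.nn.Embedding(num_embeddings=1_000, embedding_dim=64, padding_idx=0)

ids = torch.tensor([[5, 9, 0, 0]])  # trailing zeros are padding tokens
out = emb(ids)                      # -> shape (1, 4, 64)
```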
@@ -416,7 +434,19 @@ class Linear4bit(nn.Linear):
class LinearFP4(Linear4bit):
"""
Implements the FP4 data type.
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
"""
Args:
input_features (`int`):
Number of input features of the linear layer.
output_features (`int`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)
@@ -432,6 +462,15 @@ class LinearNF4(Linear4bit):
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
"""
Args:
input_features (`int`):
Number of input features of the linear layer.
output_features (`int`):
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)
......
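To ground the `LinearFP4`/`LinearNF4` docstrings above, a small sketch of swapping a dense layer for a 4-bit one (assumes a CUDA-capable GPU; the 4-bit weights are typically quantized when the module is moved to the device, and the names and sizes here are illustrative):

```py
import torch
import bitsandbytes as bnb

# NF4-quantized replacement for an nn.Linear(768, 3072).
layer = bnb.nn.LinearNF4(
    input_features=768,
    output_features=3072,
    bias=True,
    compute_dtype=torch.bfloat16,  # dtype used for the actual matmul
)

layer = layer.cuda()  # weights are quantized to 4-bit once on the GPU
x = torch.randn(2, 768, dtype=torch.bfloat16, device="cuda")
y = layer(x)          # -> shape (2, 3072)
```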
@@ -20,6 +20,33 @@ class Adagrad(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
Base Adagrad optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
lr_decay (`int`, defaults to 0):
The learning rate decay.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
initial_accumulator_value (`int`, defaults to 0):
The initial value of the accumulator.
eps (`float`, defaults to 1e-10):
The epsilon value prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
@@ -62,6 +89,33 @@ class Adagrad8bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
8-bit Adagrad optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
lr_decay (`int`, defaults to 0):
The learning rate decay.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
initial_accumulator_value (`int`, defaults to 0):
The initial value of the accumulator.
eps (`float`, defaults to 1e-10):
The epsilon value prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 8):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
@@ -105,6 +159,33 @@ class Adagrad32bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
32-bit Adagrad optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
lr_decay (`int`, defaults to 0):
The learning rate decay.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
initial_accumulator_value (`int`, defaults to 0):
The initial value of the accumulator.
eps (`float`, defaults to 1e-10):
The epsilon value prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
......
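A short training-step sketch with the 8-bit Adagrad documented above (assumes a CUDA device; per the docstring, parameter tensors smaller than `min_8bit_size` keep 32-bit state):

```py
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()
opt = bnb.optim.Adagrad8bit(model.parameters(), lr=1e-2, eps=1e-10)

x = torch.randn(8, 512, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```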
@@ -16,31 +16,205 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class Adam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Base Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Adam8bit(Optimizer2State): class Adam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Adam32bit(Optimizer2State): class Adam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
32-bit Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class PagedAdam(Optimizer2State): class PagedAdam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Paged Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdam8bit(Optimizer2State): class PagedAdam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit paged Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdam32bit(Optimizer2State): class PagedAdam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Paged 32-bit Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class AnalysisAdam(torch.optim.Optimizer): class AnalysisAdam(torch.optim.Optimizer):
......
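The Adam variants above are intended as drop-in replacements for `torch.optim.Adam`; a minimal sketch (assumes a CUDA device; the model and batch are illustrative):

```py
import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
).cuda()

# 8-bit optimizer states; bnb.optim.PagedAdam8bit is the paged variant.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

batch = torch.randn(16, 1024, device="cuda")
loss = model(batch).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```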
@@ -8,30 +8,204 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Base AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
class AdamW8bit(Optimizer2State): class AdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
class AdamW32bit(Optimizer2State): class AdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
32-bit AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class PagedAdamW(Optimizer2State): class PagedAdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdamW8bit(Optimizer2State): class PagedAdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged 8-bit AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdamW32bit(Optimizer2State): class PagedAdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged 32-bit AdamW optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
@@ -23,6 +23,39 @@ class LAMB(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
"""
Base LAMB optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
bias_correction (`bool`, defaults to `True`):
Whether to apply bias correction to the first and second-order moments.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
adam_w_mode (`bool`, defaults to `True`):
Whether to use the AdamW variant.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `False`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 1.0):
The maximum gradient norm.
"""
super().__init__(
"lamb",
params,
@@ -56,6 +89,33 @@ class LAMB8bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
"""
8-bit LAMB optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
bias_correction (`bool`, defaults to `True`):
Whether to apply bias correction to the first and second-order moments.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
adam_w_mode (`bool`, defaults to `True`):
Whether to use the AdamW variant.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `False`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 1.0):
The maximum gradient norm.
"""
super().__init__(
"lamb",
params,
@@ -89,6 +153,37 @@ class LAMB32bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
"""
32-bit LAMB optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
bias_correction (`bool`, defaults to `True`):
Whether to apply bias correction to the first and second-order moments.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
adam_w_mode (`bool`, defaults to `True`):
Whether to use the AdamW variant.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `False`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 1.0):
The maximum gradient norm.
"""
super().__init__(
"lamb",
params,
......
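A sketch for the LAMB variants documented above; the shown signatures default to `block_wise=False` and expose `max_unorm` (assumes a CUDA device):

```py
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# max_unorm caps the gradient norm, as described in the docstring above.
optimizer = bnb.optim.LAMB8bit(model.parameters(), lr=1e-3, max_unorm=1.0)
```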
@@ -23,6 +23,33 @@ class LARS(Optimizer1State):
percentile_clipping=100,
max_unorm=0.02,
):
"""
Base LARS optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
max_unorm (`float`, defaults to 0.02):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
@@ -57,6 +84,31 @@ class LARS8bit(Optimizer1State):
percentile_clipping=100,
max_unorm=0.02,
):
"""
8-bit LARS optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
max_unorm (`float`, defaults to 0.02):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
@@ -91,6 +143,31 @@ class LARS32bit(Optimizer1State):
percentile_clipping=100,
max_unorm=0.02,
):
"""
32-bit LARS optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 1e-2):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
max_unorm (`float`, defaults to 0.02):
The maximum gradient norm.
"""
if momentum == 0:
raise NotImplementedError(
"LARS without momentum is not supported!"
......
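A sketch for the LARS variants above; the guard shown in the diff means `momentum` must be non-zero (assumes a CUDA device):

```py
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(256, 256).cuda()

# momentum=0 raises NotImplementedError, so always pass a positive momentum.
optimizer = bnb.optim.LARS8bit(model.parameters(), lr=0.1, momentum=0.9)
```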
@@ -7,25 +7,165 @@ from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Base Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.99)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Lion8bit(Optimizer1State): class Lion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.99)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Lion32bit(Optimizer1State): class Lion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
32-bit Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.99)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class PagedLion(Optimizer1State): class PagedLion(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.99)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedLion8bit(Optimizer1State): class PagedLion8bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged 8-bit Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.99)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedLion32bit(Optimizer1State): class PagedLion32bit(Optimizer1State):
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
"""
Paged 32-bit Lion optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-4):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.99)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
weight_decay (`float`, defaults to 0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
...@@ -18,6 +18,9 @@ class MockArgs: ...@@ -18,6 +18,9 @@ class MockArgs:
class GlobalOptimManager: class GlobalOptimManager:
"""
A global optimizer manager for enabling custom optimizer configs.
"""
_instance = None
def __init__(self):
...@@ -53,22 +56,40 @@ class GlobalOptimManager:
self, parameters, key=None, value=None, key_value_dict=None
):
""" """
Overrides initial optimizer config for specific parameters. Override initial optimizer config with specific hyperparameters.
The key-values of the optimizer config for the input parameters are overridden The key-values of the optimizer config for the input parameters are overridden
This can be both, optimizer parameters like "betas", or "lr" or it can be This can be both, optimizer parameters like `betas` or `lr`, or it can be
8-bit specific parameters like "optim_bits", "percentile_clipping". 8-bit specific parameters like `optim_bits` or `percentile_clipping`.
Parameters Arguments:
---------- parameters (`torch.Tensor` or `list(torch.Tensors)`):
parameters : torch.Tensor or list(torch.Tensors) The input parameters.
The input parameters. key (`str`):
key : str The hyperparamter to override.
The hyperparamter to override. value:
value : object The hyperparameter values.
The value for the hyperparamters. key_value_dict (`dict`):
key_value_dict : dict A dictionary with multiple key-values to override.
A dictionary with multiple key-values to override.
Example:
```py
import torch
import bitsandbytes as bnb
mng = bnb.optim.GlobalOptimManager.get_instance()
model = MyModel()
mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU
model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
# 2. override: the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, 'optim_bits', 32)
```
""" """
self.uses_config_override = True self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter): if isinstance(parameters, torch.nn.Parameter):
...@@ -92,6 +113,17 @@ class GlobalOptimManager: ...@@ -92,6 +113,17 @@ class GlobalOptimManager:
class Optimizer8bit(torch.optim.Optimizer): class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32, is_paged=False): def __init__(self, params, defaults, optim_bits=32, is_paged=False):
"""
Base 8-bit optimizer class.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
defaults (`dict`):
The default optimizer hyperparameter values.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__(params, defaults)
self.initialized = False
self.name2qmap = {}
...@@ -125,11 +157,11 @@ class Optimizer8bit(torch.optim.Optimizer):
super().__setstate__(state)
def load_state_dict(self, state_dict):
r"""Loads the optimizer state. """Load an optimizer state.
Args: Arguments:
state_dict (dict): optimizer state. Should be an object returned state_dict (`dict`):
from a call to :meth:`state_dict`. An optimizer state (should be returned from a call to `state_dict`) to load.
""" """
# deepcopy, to be consistent with module API # deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict) state_dict = deepcopy(state_dict)
...@@ -237,11 +269,11 @@ class Optimizer8bit(torch.optim.Optimizer):
@torch.no_grad()
def step(self, closure=None):
"""Perform a single optimization step.
Arguments:
closure (`Callable`, *optional*, defaults to `None`):
A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
...@@ -339,6 +371,39 @@ class Optimizer2State(Optimizer8bit):
skip_zeros=False,
is_paged=False
):
"""
Base 2-state update optimizer class.
Arguments:
optimizer_name (`str`):
The name of the optimizer.
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple`, defaults to (0.9, 0.999)):
The beta values for the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value for the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 0.0):
The maximum value to normalize each block with.
skip_zeros (`bool`, defaults to `False`):
Whether to skip zero values for sparse gradients and models to ensure correct updates.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
...@@ -552,6 +617,39 @@ class Optimizer1State(Optimizer8bit):
skip_zeros=False,
is_paged=False
):
"""
Base 1-state update optimizer class.
Arguments:
optimizer_name (`str`):
The name of the optimizer.
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple`, defaults to (0.9, 0.0)):
The beta values for the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value for the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
max_unorm (`float`, defaults to 0.0):
The maximum value to normalize each block with.
skip_zeros (`bool`, defaults to `False`):
Whether to skip zero values for sparse gradients and models to ensure correct updates.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
...
...@@ -21,6 +21,35 @@ class RMSprop(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
Base RMSprop optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
alpha (`float`, defaults to 0.99):
The alpha value is the decay rate of the squared gradients of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
centered (`bool`, defaults to `False`):
Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
...@@ -57,6 +86,35 @@ class RMSprop8bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
8-bit RMSprop optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
alpha (`float`, defaults to 0.99):
The alpha value is the decay rate of the squared gradients of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
centered (`bool`, defaults to `False`):
Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
...@@ -93,6 +151,35 @@ class RMSprop32bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
32-bit RMSprop optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-2):
The learning rate.
alpha (`float`, defaults to 0.99):
The alpha value is the decay rate of the squared gradients of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
centered (`bool`, defaults to `False`):
Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if alpha == 0:
raise NotImplementedError(
...
...@@ -20,6 +20,33 @@ class SGD(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
Base SGD optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if momentum == 0:
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
...@@ -51,6 +78,31 @@ class SGD8bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
8-bit SGD optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if momentum == 0:
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
...@@ -82,6 +134,31 @@ class SGD32bit(Optimizer1State):
percentile_clipping=100,
block_wise=True,
):
"""
32-bit SGD optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`):
The learning rate.
momentum (`float`, defaults to 0):
The momentum value speeds up the optimizer by taking bigger steps.
dampening (`float`, defaults to 0):
The dampening value reduces the momentum of the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
nesterov (`bool`, defaults to `False`):
Whether to use Nesterov momentum.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
if momentum == 0:
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
...
...@@ -26,5 +26,33 @@
title: Papers, resources & how to cite
- title: API reference
sections:
- local: reference/quantization
title: Quantization
- title: Optimizers
sections:
- local: reference/optim/optim_overview
title: Overview
- local: reference/optim/adagrad
title: AdaGrad
- local: reference/optim/adam
title: Adam
- local: reference/optim/adamw
title: AdamW
- local: reference/optim/lamb
title: LAMB
- local: reference/optim/lars
title: LARS
- local: reference/optim/lion
title: Lion
- local: reference/optim/rmsprop
title: RMSprop
- local: reference/optim/sgd
title: SGD
- title: k-bit quantizers
sections:
- local: reference/nn/linear8bit
title: 8-bit quantizer
- local: reference/nn/linear4bit
title: 4-bit quantizer
- local: reference/nn/embeddings
title: Embedding
# Embedding
The embedding class is used to store and retrieve word embeddings from their indices. There are two types of embeddings in bitsandbytes: the standard PyTorch [`Embedding`] class and the [`StableEmbedding`] class.
The [`StableEmbedding`] class was introduced in the [8-bit Optimizers via Block-wise Quantization](https://hf.co/papers/2110.02861) paper to reduce gradient variance as a result of the non-uniform distribution of input tokens. This class is designed to support quantization.
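The snippet below is a minimal usage sketch of [`StableEmbedding`] as a drop-in replacement for `torch.nn.Embedding`; the vocabulary size, embedding dimension, and batch shape are illustrative.

```py
import torch
import bitsandbytes as bnb

# drop-in replacement for torch.nn.Embedding; the output is layer-normalized
emb = bnb.nn.StableEmbedding(num_embeddings=10_000, embedding_dim=256)

token_ids = torch.randint(0, 10_000, (2, 16))  # (batch, sequence_length)
vectors = emb(token_ids)                       # (2, 16, 256)
```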
## Embedding
[[autodoc]] bitsandbytes.nn.Embedding
- __init__
## StableEmbedding
[[autodoc]] bitsandbytes.nn.StableEmbedding
- __init__
# 4-bit quantization
[QLoRA](https://hf.co/papers/2305.14314) is a finetuning method that quantizes a model to 4-bit and adds a set of low-rank adaptation (LoRA) weights, which are tuned by backpropagating through the quantized weights. This method also introduces a new data type, 4-bit NormalFloat (`LinearNF4`), in addition to the standard Float4 data type (`LinearFP4`). `LinearNF4` is a quantization data type for normally distributed data and can improve performance.
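As a rough usage sketch (the layer sizes and dtypes are illustrative, and a CUDA device is assumed since the weights are quantized when the layer is moved to the GPU):

```py
import torch
import bitsandbytes as bnb

# 4-bit NormalFloat linear layer; LinearFP4 is constructed the same way
layer = bnb.nn.LinearNF4(768, 3072)

# the weights are quantized to 4-bit when the layer is moved to the GPU
layer = layer.cuda()

x = torch.randn(1, 768, dtype=torch.float16, device="cuda")
y = layer(x)
```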
## Linear4bit
[[autodoc]] bitsandbytes.nn.Linear4bit
- __init__
## LinearFP4
[[autodoc]] bitsandbytes.nn.LinearFP4
- __init__
## LinearNF4
[[autodoc]] bitsandbytes.nn.LinearNF4
- __init__
## Params4bit
[[autodoc]] bitsandbytes.nn.Params4bit
- __init__
# 8-bit quantization
[LLM.int8()](https://hf.co/papers/2208.07339) is a quantization method that doesn't degrade performance, which makes large model inference more accessible. The key is to extract the outliers from the inputs and weights and multiply them in 16-bit. All other values are multiplied in 8-bit and quantized to Int8 before being dequantized back to 16-bit. The outputs from the 16-bit and 8-bit multiplication are combined to produce the final output.
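For example, a minimal inference-style sketch (the sizes are illustrative; `has_fp16_weights=False` keeps the int8 weights, and `threshold=6.0` routes outlier features through the 16-bit path):

```py
import torch
import bitsandbytes as bnb

# int8 linear layer for inference with outlier decomposition
layer = bnb.nn.Linear8bitLt(768, 3072, has_fp16_weights=False, threshold=6.0)

# quantization of the weights happens when the layer is moved to the GPU
layer = layer.cuda()

x = torch.randn(1, 768, dtype=torch.float16, device="cuda")
y = layer(x)
```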
## Linear8bitLt
[[autodoc]] bitsandbytes.nn.Linear8bitLt
- __init__
## Int8Params
[[autodoc]] bitsandbytes.nn.Int8Params
- __init__
# AdaGrad
[AdaGrad (Adaptive Gradient)](https://jmlr.org/papers/v12/duchi11a.html) is an adaptive learning rate optimizer. AdaGrad stores a sum of the squared past gradients for each parameter and scales that parameter's learning rate by the inverse square root of this sum. Parameters with consistently large gradients therefore take smaller steps, while infrequently updated parameters take relatively larger ones, eliminating the need to manually tune the learning rate.
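The toy sketch below illustrates this accumulation rule in plain PyTorch; it is not the bitsandbytes implementation, and the learning rate and epsilon values are illustrative.

```py
import torch

param = torch.zeros(3)
state_sum = torch.zeros_like(param)  # running sum of squared gradients
lr, eps = 1e-2, 1e-10

for grad in (torch.tensor([1.0, 0.1, 0.0]), torch.tensor([1.0, 0.2, 0.0])):
    state_sum += grad**2                           # accumulate squared gradients
    param -= lr * grad / (state_sum.sqrt() + eps)  # per-parameter scaled step

# parameters with large accumulated gradients receive smaller effective steps
```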
## Adagrad[[api-class]]
[[autodoc]] bitsandbytes.optim.Adagrad
- __init__
## Adagrad8bit
[[autodoc]] bitsandbytes.optim.Adagrad8bit
- __init__
## Adagrad32bit
[[autodoc]] bitsandbytes.optim.Adagrad32bit
- __init__
# Adam
[Adam (Adaptive moment estimation)](https://hf.co/papers/1412.6980) is an adaptive learning rate optimizer, combining ideas from [`SGD`] with momentum and [`RMSprop`] to automatically scale the learning rate:
- a weighted average of the past gradients to provide direction (first-moment)
- a weighted average of the *squared* past gradients to adapt the learning rate to each parameter (second-moment)
bitsandbytes also supports paged optimizers, which take advantage of CUDA's unified memory to transfer optimizer states from the GPU to the CPU when GPU memory is exhausted.
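A minimal sketch of the drop-in usage (the model, learning rate, and betas are illustrative, and a CUDA device is assumed):

```py
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# 8-bit Adam whose optimizer states can be paged to CPU memory under GPU memory pressure
optimizer = bnb.optim.PagedAdam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```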
## Adam[[api-class]]
[[autodoc]] bitsandbytes.optim.Adam
- __init__
## Adam8bit
[[autodoc]] bitsandbytes.optim.Adam8bit
- __init__
## Adam32bit
[[autodoc]] bitsandbytes.optim.Adam32bit
- __init__
## PagedAdam
[[autodoc]] bitsandbytes.optim.PagedAdam
- __init__
## PagedAdam8bit
[[autodoc]] bitsandbytes.optim.PagedAdam8bit
- __init__
## PagedAdam32bit
[[autodoc]] bitsandbytes.optim.PagedAdam32bit
- __init__
# AdamW
[AdamW](https://hf.co/papers/1711.05101) is a variant of the [`Adam`] optimizer that separates weight decay from the gradient update based on the observation that the weight decay formulation is different when applied to [`SGD`] and [`Adam`].
bitsandbytes also supports paged optimizers, which take advantage of CUDA's unified memory to transfer optimizer states from the GPU to the CPU when GPU memory is exhausted.
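For instance, as a drop-in replacement for `torch.optim.AdamW` (the learning rate and weight decay values are illustrative):

```py
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# 8-bit AdamW with decoupled weight decay; PagedAdamW8bit additionally pages states to CPU memory
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=2e-4, weight_decay=0.01)
```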
## AdamW[[api-class]]
[[autodoc]] bitsandbytes.optim.AdamW
- __init__
## AdamW8bit
[[autodoc]] bitsandbytes.optim.AdamW8bit
- __init__
## AdamW32bit
[[autodoc]] bitsandbytes.optim.AdamW32bit
- __init__
## PagedAdamW
[[autodoc]] bitsandbytes.optim.PagedAdamW
- __init__
## PagedAdamW8bit
[[autodoc]] bitsandbytes.optim.PagedAdamW8bit
- __init__
## PagedAdamW32bit
[[autodoc]] bitsandbytes.optim.PagedAdamW32bit
- __init__
# LAMB
[LAMB (Layerwise adaptive large batch optimization)](https://hf.co/papers/1904.00962) is an adaptive optimizer designed to accelerate training with large batch sizes, combining ideas from [`LARS`] and [`Adam`] to automatically scale the learning rate for each layer:
- calculates a *trust ratio* between the weight and gradient norm in a layer and clips the ratio to prevent overly large or small updates (a rough sketch follows the list)
- updates the weights with the first and second moments
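The snippet below is only a rough sketch of the trust-ratio idea for a single layer; the tensors, clipping bound, and learning rate are illustrative and not the library's exact formulation.

```py
import torch

weight = torch.randn(1024, 1024)
update = torch.randn(1024, 1024) * 1e-3  # e.g. an Adam-style update for this layer

# ratio of the weight norm to the update norm, clipped to a maximum value
trust_ratio = (weight.norm() / (update.norm() + 1e-7)).clamp(max=10.0)

lr = 1e-3
weight = weight - lr * trust_ratio * update
```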
## LAMB[[api-class]]
[[autodoc]] bitsandbytes.optim.LAMB
- __init__
## LAMB8bit
[[autodoc]] bitsandbytes.optim.LAMB8bit
- __init__
## LAMB32bit
[[autodoc]] bitsandbytes.optim.LAMB32bit
- __init__
# LARS
[LARS (Layer-wise Adaptive Rate Scaling)](https://hf.co/papers/1708.03888) is an optimizer designed to accelerate training with large batch sizes. LARS uses a separate learning rate for each *layer* instead of each parameter, calculated from a *trust ratio* between the weight norm and gradient norm in a layer. This helps calibrate a stable update size.
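A minimal usage sketch (the learning rate and momentum values are illustrative; LARS in bitsandbytes is typically used with momentum):

```py
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# layer-wise adaptive rate scaling with 8-bit optimizer states
optimizer = bnb.optim.LARS8bit(model.parameters(), lr=0.1, momentum=0.9)
```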
## LARS[[api-class]]
[[autodoc]] bitsandbytes.optim.LARS
- __init__
## LARS8bit
[[autodoc]] bitsandbytes.optim.LARS8bit
- __init__
## LARS32bit
[[autodoc]] bitsandbytes.optim.LARS32bit
- __init__