Unverified commit bf015381, authored by Titus, committed by GitHub

Merge pull request #1128 from akx/type-fixes

Minor type/doc fixes
Parents: b03ce0e0, 0c6dda08
@@ -1618,18 +1618,18 @@ def optimizer_update_8bit(
     g: Tensor,
     p: Tensor,
     state1: Tensor,
-    state2: Tensor,
+    state2: Optional[torch.Tensor],
     beta1: float,
     beta2: float,
     eps: float,
     step: int,
     lr: float,
     qmap1: Tensor,
-    qmap2: Tensor,
+    qmap2: Optional[torch.Tensor],
     max1: Tensor,
-    max2: Tensor,
+    max2: Optional[torch.Tensor],
     new_max1: Tensor,
-    new_max2: Tensor,
+    new_max2: Optional[torch.Tensor],
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
     unorm_vec: Optional[torch.Tensor] = None,
@@ -1751,16 +1751,16 @@ def optimizer_update_8bit_blockwise(
     g: Tensor,
     p: Tensor,
     state1: Tensor,
-    state2: Tensor,
+    state2: Optional[torch.Tensor],
     beta1: float,
     beta2: float,
     eps: float,
     step: int,
     lr: float,
     qmap1: Tensor,
-    qmap2: Tensor,
+    qmap2: Optional[torch.Tensor],
     absmax1: Tensor,
-    absmax2: Tensor,
+    absmax2: Optional[torch.Tensor],
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
     skip_zeros=False,

...
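The signature hunks above mark the second-moment buffers (`state2`, `qmap2`, `max2`/`absmax2`, `new_max2`) as `Optional`: single-state optimizers track only `state1` and pass `None` for the second state, while two-state optimizers pass real tensors. A minimal sketch of that pattern, with hypothetical names (this is not the bitsandbytes kernel dispatch):

```python
from typing import Optional

# Hedged sketch: why state2/qmap2/absmax2 became Optional. Single-state
# optimizers (momentum SGD, RMSprop, Lion) pass None for every second-state
# buffer; two-state optimizers (Adam, AdamW) pass real buffers.
# `update_sketch` is an illustrative stand-in, not the real kernel call.
def update_sketch(state1: list, state2: Optional[list], beta1: float, beta2: float) -> list:
    if state2 is None:
        # 1-state path: only the first moment participates
        return [beta1 * s for s in state1]
    # 2-state path: combine both moments elementwise
    return [beta1 * s1 + beta2 * s2 for s1, s2 in zip(state1, state2)]
```

Typing the parameter as plain `Tensor` would make every `None`-passing 1-state call site a type error, which is what this hunk fixes.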
@@ -658,8 +658,8 @@ class Linear8bitLt(nn.Linear):
     def __init__(
         self,
-        input_features,
-        output_features,
+        input_features: int,
+        output_features: int,
         bias=True,
         has_fp16_weights=True,
         memory_efficient_backward=False,
@@ -671,9 +671,9 @@ class Linear8bitLt(nn.Linear):
         Initialize Linear8bitLt class.

         Args:
-            input_features (`str`):
+            input_features (`int`):
                 Number of input features of the linear layer.
-            output_features (`str`):
+            output_features (`int`):
                 Number of output features of the linear layer.
             bias (`bool`, defaults to `True`):
                 Whether the linear class uses the bias term as well.

...
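These two hunks align the annotations with the docstring: the feature arguments are counts, so both are `int` (the old docstring said `str`). A minimal sketch of the corrected signature, using a hypothetical stand-in class rather than the bitsandbytes implementation:

```python
# Hedged sketch: mirrors the corrected Linear8bitLt.__init__ annotations
# (feature counts are `int`, not `str`). LinearStub is a hypothetical
# stand-in, not the bitsandbytes Linear8bitLt.
class LinearStub:
    def __init__(self, input_features: int, output_features: int, bias: bool = True):
        # number of input/output features of the linear layer
        self.input_features = input_features
        self.output_features = output_features
        self.bias = bias

layer = LinearStub(768, 3072, bias=True)
```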
@@ -38,8 +38,8 @@ class Adagrad(Optimizer1State):
             The epsilon value prevents division by zero in the optimizer.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -105,8 +105,8 @@ class Adagrad8bit(Optimizer1State):
             The epsilon value prevents division by zero in the optimizer.
         optim_bits (`int`, defaults to 8):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -173,8 +173,8 @@ class Adagrad32bit(Optimizer1State):
             The epsilon value prevents division by zero in the optimizer.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

...
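The recurring docstring fix in these hunks changes `args` from `dict` to `object`: the override values are read as attributes (`args.lr`), not dict keys (`args["lr"]`). A minimal sketch of the distinction, assuming attribute-style reads; `MockOptimizer` is illustrative, not the bitsandbytes `Optimizer1State`:

```python
from types import SimpleNamespace

# Hedged sketch: `args` is documented as an object because its fields are
# read as attributes, so a plain dict would not work here.
# MockOptimizer is a hypothetical stand-in for illustration only.
class MockOptimizer:
    def __init__(self, lr: float = 0.01, args: object = None):
        # when an args object is supplied, its attributes override the defaults
        self.lr = args.lr if args is not None else lr

opt = MockOptimizer(args=SimpleNamespace(lr=0.1))
default_opt = MockOptimizer()
```

`types.SimpleNamespace(lr=0.1)` is one convenient way to build such an object; a dict would raise `AttributeError` on `args.lr`.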
@@ -47,8 +47,8 @@ class Adam(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -108,8 +108,8 @@ class Adam8bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -169,8 +169,8 @@ class Adam32bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -230,8 +230,8 @@ class PagedAdam(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -291,8 +291,8 @@ class PagedAdam8bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -352,8 +352,8 @@ class PagedAdam32bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

...
@@ -39,8 +39,8 @@ class AdamW(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -100,8 +100,8 @@ class AdamW8bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -161,8 +161,8 @@ class AdamW32bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -221,8 +221,8 @@ class PagedAdamW(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -281,8 +281,8 @@ class PagedAdamW8bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -341,8 +341,8 @@ class PagedAdamW32bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

...
@@ -45,8 +45,8 @@ class LAMB(Optimizer2State):
             Whether to use the AdamW variant.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -109,8 +109,8 @@ class LAMB8bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         adam_w_mode (`bool`, defaults to `True`):
             Whether to use the AdamW variant.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -173,8 +173,8 @@ class LAMB32bit(Optimizer2State):
             Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
         adam_w_mode (`bool`, defaults to `True`):
             Whether to use the AdamW variant.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

...
@@ -41,8 +41,8 @@ class LARS(Optimizer1State):
             Whether to use Nesterov momentum.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -98,8 +98,8 @@ class LARS8bit(Optimizer1State):
             The weight decay value for the optimizer.
         nesterov (`bool`, defaults to `False`):
             Whether to use Nesterov momentum.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -155,8 +155,8 @@ class LARS32bit(Optimizer1State):
             The weight decay value for the optimizer.
         nesterov (`bool`, defaults to `False`):
             Whether to use Nesterov momentum.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

...
@@ -33,8 +33,8 @@ class Lion(Optimizer1State):
             The weight decay value for the optimizer.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -85,8 +85,8 @@ class Lion8bit(Optimizer1State):
             The beta values are the decay rates of the first and second-order moment of the optimizer.
         weight_decay (`float`, defaults to 0):
             The weight decay value for the optimizer.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -137,8 +137,8 @@ class Lion32bit(Optimizer1State):
             The beta values are the decay rates of the first and second-order moment of the optimizer.
         weight_decay (`float`, defaults to 0):
             The weight decay value for the optimizer.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -191,8 +191,8 @@ class PagedLion(Optimizer1State):
             The weight decay value for the optimizer.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -242,8 +242,8 @@ class PagedLion8bit(Optimizer1State):
             The weight decay value for the optimizer.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -293,8 +293,8 @@ class PagedLion32bit(Optimizer1State):
             The weight decay value for the optimizer.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

...
@@ -373,8 +373,8 @@ class Optimizer2State(Optimizer8bit):
             The weight decay value for the optimizer.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

@@ -596,8 +596,8 @@ class Optimizer1State(Optimizer8bit):
             The weight decay value for the optimizer.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
-        args (`dict`, defaults to `None`):
-            A dictionary with additional arguments.
+        args (`object`, defaults to `None`):
+            An object with additional arguments.
         min_8bit_size (`int`, defaults to 4096):
             The minimum number of elements of the parameter tensors for 8-bit optimization.
         percentile_clipping (`int`, defaults to 100):

...
...@@ -41,8 +41,8 @@ class RMSprop(Optimizer1State): ...@@ -41,8 +41,8 @@ class RMSprop(Optimizer1State):
Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute.
optim_bits (`int`, defaults to 32): optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state. The number of bits of the optimizer state.
args (`dict`, defaults to `None`): args (`object`, defaults to `None`):
A dictionary with additional arguments. An object with additional arguments.
min_8bit_size (`int`, defaults to 4096): min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization. The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100): percentile_clipping (`int`, defaults to 100):
...@@ -104,8 +104,8 @@ class RMSprop8bit(Optimizer1State): ...@@ -104,8 +104,8 @@ class RMSprop8bit(Optimizer1State):
Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute. Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute.
optim_bits (`int`, defaults to 32): optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state. The number of bits of the optimizer state.
args (`dict`, defaults to `None`): args (`object`, defaults to `None`):
A dictionary with additional arguments. An object with additional arguments.
min_8bit_size (`int`, defaults to 4096): min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization. The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100): percentile_clipping (`int`, defaults to 100):
@@ -167,8 +167,8 @@ class RMSprop32bit(Optimizer1State):
         Whether the gradients are normalized by the variance. If `True`, it can help training at the expense of additional compute.
     optim_bits (`int`, defaults to 32):
         The number of bits of the optimizer state.
-    args (`dict`, defaults to `None`):
-        A dictionary with additional arguments.
+    args (`object`, defaults to `None`):
+        An object with additional arguments.
     min_8bit_size (`int`, defaults to 4096):
         The minimum number of elements of the parameter tensors for 8-bit optimization.
     percentile_clipping (`int`, defaults to 100):
...
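The `percentile_clipping` parameter documented above can be sketched conceptually: track a window of recent gradient norms and scale down any gradient whose norm exceeds the chosen percentile of that history. This is a minimal illustrative sketch in plain NumPy, not the bitsandbytes CUDA implementation; the function name `percentile_clip` and its exact window handling are assumptions.

```python
import numpy as np

def percentile_clip(grad, gnorm_history, percentile=100, window=100):
    """Hypothetical sketch of percentile gradient clipping: keep a sliding
    window of recent gradient norms and rescale gradients whose norm
    exceeds the given percentile of that history."""
    gnorm = float(np.linalg.norm(grad))
    gnorm_history.append(gnorm)
    del gnorm_history[:-window]  # keep only the most recent `window` norms
    clip_at = np.percentile(gnorm_history, percentile)
    # scale down (never up) so the clipped norm does not exceed clip_at
    scale = min(1.0, clip_at / (gnorm + 1e-8))
    return grad * scale, gnorm_history
```

With the default `percentile=100` the threshold is the window maximum, so gradients pass through essentially unscaled; lower percentiles shrink outlier gradients toward the history.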
@@ -38,8 +38,8 @@ class SGD(Optimizer1State):
         Whether to use Nesterov momentum.
     optim_bits (`int`, defaults to 32):
         The number of bits of the optimizer state.
-    args (`dict`, defaults to `None`):
-        A dictionary with additional arguments.
+    args (`object`, defaults to `None`):
+        An object with additional arguments.
     min_8bit_size (`int`, defaults to 4096):
         The minimum number of elements of the parameter tensors for 8-bit optimization.
     percentile_clipping (`int`, defaults to 100):
@@ -94,8 +94,8 @@ class SGD8bit(Optimizer1State):
         The weight decay value for the optimizer.
     nesterov (`bool`, defaults to `False`):
         Whether to use Nesterov momentum.
-    args (`dict`, defaults to `None`):
-        A dictionary with additional arguments.
+    args (`object`, defaults to `None`):
+        An object with additional arguments.
     min_8bit_size (`int`, defaults to 4096):
         The minimum number of elements of the parameter tensors for 8-bit optimization.
     percentile_clipping (`int`, defaults to 100):
@@ -150,8 +150,8 @@ class SGD32bit(Optimizer1State):
         The weight decay value for the optimizer.
     nesterov (`bool`, defaults to `False`):
         Whether to use Nesterov momentum.
-    args (`dict`, defaults to `None`):
-        A dictionary with additional arguments.
+    args (`object`, defaults to `None`):
+        An object with additional arguments.
     min_8bit_size (`int`, defaults to 4096):
         The minimum number of elements of the parameter tensors for 8-bit optimization.
     percentile_clipping (`int`, defaults to 100):
...
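The `min_8bit_size` threshold that appears throughout these docstrings can be illustrated with a small sketch: parameter tensors below the threshold keep full-precision optimizer state, since quantization overhead is not worthwhile for tiny tensors. The helper name `choose_state_bits` is illustrative, not part of the bitsandbytes API.

```python
def choose_state_bits(numel, min_8bit_size=4096, optim_bits=8):
    """Hypothetical sketch of the min_8bit_size gate: a parameter tensor
    only gets 8-bit optimizer state when 8-bit optimization is requested
    AND the tensor has at least min_8bit_size elements; otherwise it
    falls back to 32-bit state."""
    if optim_bits == 8 and numel >= min_8bit_size:
        return 8
    return 32
```

For example, a 4096-element embedding row would get 8-bit state under the defaults, while a small bias vector would stay in 32-bit.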
@@ -140,7 +140,7 @@ def replace_linear(
         List of modules names not to convert. Defaults to `lm_head`.
     copy_weights (`bool`):
         Copy the weights from the old linear module to the new one
-    post_processing_fun_name (`str`):
+    post_processing_function (`str`):
         A function name of the replacement linear class that is called
         after processing.
     """
...
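A `replace_linear`-style helper like the one whose docstring is fixed above typically walks the module tree, swaps each `nn.Linear` for a replacement class, and then invokes a named post-processing method on the new module. This is a hedged sketch under assumed names (`replace_linear_sketch`, the `skip` tuple, the keyword handling), not the exact bitsandbytes signature.

```python
import torch.nn as nn

def replace_linear_sketch(model, new_cls, skip=("lm_head",), post_processing_function=None):
    """Sketch of a linear-replacement pass: recursively swap nn.Linear
    children for new_cls, skipping modules named in `skip`, and call the
    method named by post_processing_function on each replacement."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and name not in skip:
            new = new_cls(child.in_features, child.out_features, bias=child.bias is not None)
            if post_processing_function is not None:
                # look up the named method on the replacement class and call it
                getattr(new, post_processing_function)()
            setattr(model, name, new)
        else:
            replace_linear_sketch(child, new_cls, skip, post_processing_function)
    return model
```

Passing the method *name* as a string (rather than a callable) is what makes `post_processing_function (str)` in the docstring above sensible: the hook is resolved per replacement module via `getattr`.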