@@ -87,7 +87,7 @@ class AdamW8bit(Optimizer2State):
         8-bit AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
...
...
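# --- Usage sketch (illustrative, not part of the patch) ---
# A minimal sketch of driving AdamW8bit, assuming bitsandbytes is installed
# and a CUDA device is available; the model and hyperparameters are made up.
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(4096, 4096).cuda()
# Drop-in replacement for torch.optim.AdamW, with optimizer state held in 8 bits.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-3, weight_decay=1e-2)

loss = model(torch.randn(8, 4096, device="cuda")).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()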
@@ -159,7 +159,7 @@ class AdamW32bit(Optimizer2State):
         32-bit AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
...
...
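# --- Usage sketch (illustrative, not part of the patch) ---
# AdamW32bit keeps full-precision optimizer state while sharing the same
# constructor shape; betas and eps below are the documented defaults.
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = bnb.optim.AdamW32bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)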
@@ -219,7 +219,7 @@ class PagedAdamW(Optimizer2State):
         Paged AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
...
...
@@ -241,8 +241,6 @@ class PagedAdamW(Optimizer2State):
                 Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-            is_paged (`bool`, defaults to `False`):
-                Whether the optimizer is a paged optimizer or not.
         """
         super().__init__(
             "adam",
...
...
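# --- Usage sketch (illustrative, not part of the patch) ---
# PagedAdamW is always paged: the constructor exposes no is_paged flag, and
# paging lets optimizer state be evicted to CPU memory when the GPU runs
# short. Assumes bitsandbytes is installed and a CUDA device is available.
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)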
@@ -279,7 +277,7 @@ class PagedAdamW8bit(Optimizer2State):
         Paged 8-bit AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
...
...
@@ -303,8 +301,6 @@ class PagedAdamW8bit(Optimizer2State):
                 Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-            is_paged (`bool`, defaults to `False`):
-                Whether the optimizer is a paged optimizer or not.
         """
         # Validate unsupported parameters
         if amsgrad:
...
...
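# --- Usage sketch (illustrative, not part of the patch) ---
# PagedAdamW8bit combines paged state with 8-bit quantization; the two
# keyword arguments below are shown with their documented defaults.
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = bnb.optim.PagedAdamW8bit(
    model.parameters(),
    lr=1e-3,
    min_8bit_size=4096,       # parameters with fewer elements keep 32-bit state
    percentile_clipping=100,  # clip at this percentile of the last 100 gradient norms
)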
@@ -350,7 +346,7 @@ class PagedAdamW32bit(Optimizer2State):
         Paged 32-bit AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
...
...
@@ -372,8 +368,6 @@ class PagedAdamW32bit(Optimizer2State):
                 Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-            is_paged (`bool`, defaults to `False`):
-                Whether the optimizer is a paged optimizer or not.
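# --- Usage sketch (illustrative, not part of the patch) ---
# PagedAdamW32bit pairs paged memory with full-precision optimizer state,
# for when 8-bit quantization is not wanted. Assumes bitsandbytes and CUDA.
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = bnb.optim.PagedAdamW32bit(model.parameters(), lr=1e-3, block_wise=True)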