Commit 0c6dda08 authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Mark some optimizer update arguments as Noneable (they were being called with Nones)

parent 3ec3dd26
...@@ -1618,18 +1618,18 @@ def optimizer_update_8bit( ...@@ -1618,18 +1618,18 @@ def optimizer_update_8bit(
g: Tensor, g: Tensor,
p: Tensor, p: Tensor,
state1: Tensor, state1: Tensor,
state2: Tensor, state2: Optional[torch.Tensor],
beta1: float, beta1: float,
beta2: float, beta2: float,
eps: float, eps: float,
step: int, step: int,
lr: float, lr: float,
qmap1: Tensor, qmap1: Tensor,
qmap2: Tensor, qmap2: Optional[torch.Tensor],
max1: Tensor, max1: Tensor,
max2: Tensor, max2: Optional[torch.Tensor],
new_max1: Tensor, new_max1: Tensor,
new_max2: Tensor, new_max2: Optional[torch.Tensor],
weight_decay: float = 0.0, weight_decay: float = 0.0,
gnorm_scale: float = 1.0, gnorm_scale: float = 1.0,
unorm_vec: Optional[torch.Tensor] = None, unorm_vec: Optional[torch.Tensor] = None,
...@@ -1751,16 +1751,16 @@ def optimizer_update_8bit_blockwise( ...@@ -1751,16 +1751,16 @@ def optimizer_update_8bit_blockwise(
g: Tensor, g: Tensor,
p: Tensor, p: Tensor,
state1: Tensor, state1: Tensor,
state2: Tensor, state2: Optional[torch.Tensor],
beta1: float, beta1: float,
beta2: float, beta2: float,
eps: float, eps: float,
step: int, step: int,
lr: float, lr: float,
qmap1: Tensor, qmap1: Tensor,
qmap2: Tensor, qmap2: Optional[torch.Tensor],
absmax1: Tensor, absmax1: Tensor,
absmax2: Tensor, absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0, weight_decay: float = 0.0,
gnorm_scale: float = 1.0, gnorm_scale: float = 1.0,
skip_zeros=False, skip_zeros=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment