Commit 0c6dda08 authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Mark some optimizer update arguments as Noneable (they were being called with Nones)

parent 3ec3dd26
......@@ -1618,18 +1618,18 @@ def optimizer_update_8bit(
g: Tensor,
p: Tensor,
state1: Tensor,
state2: Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
eps: float,
step: int,
lr: float,
qmap1: Tensor,
qmap2: Tensor,
qmap2: Optional[torch.Tensor],
max1: Tensor,
max2: Tensor,
max2: Optional[torch.Tensor],
new_max1: Tensor,
new_max2: Tensor,
new_max2: Optional[torch.Tensor],
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Optional[torch.Tensor] = None,
......@@ -1751,16 +1751,16 @@ def optimizer_update_8bit_blockwise(
g: Tensor,
p: Tensor,
state1: Tensor,
state2: Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
eps: float,
step: int,
lr: float,
qmap1: Tensor,
qmap2: Tensor,
qmap2: Optional[torch.Tensor],
absmax1: Tensor,
absmax2: Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment