Commit 0f6fe6bf authored by Egor Krivov's avatar Egor Krivov
Browse files

Fixed default args

parent e33ba1c0
...@@ -352,7 +352,7 @@ if ipex_cpu or ipex_xpu: ...@@ -352,7 +352,7 @@ if ipex_cpu or ipex_xpu:
torch.library.define( torch.library.define(
"bitsandbytes::optimizer_update_32bit", "bitsandbytes::optimizer_update_32bit",
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()", "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
) )
...@@ -395,7 +395,7 @@ def _( ...@@ -395,7 +395,7 @@ def _(
torch.library.define( torch.library.define(
"bitsandbytes::optimizer_update_8bit_blockwise", "bitsandbytes::optimizer_update_8bit_blockwise",
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()", "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()",
) )
...@@ -417,8 +417,8 @@ def _( ...@@ -417,8 +417,8 @@ def _(
qmap2: Optional[torch.Tensor], qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor, absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor], absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0, weight_decay: float,
gnorm_scale: float = 1.0, gnorm_scale: float,
skip_zeros=False, skip_zeros=False,
) -> None: ) -> None:
torch._check( torch._check(
......
...@@ -686,8 +686,8 @@ def _optimizer_update_8bit_blockwise_impl( ...@@ -686,8 +686,8 @@ def _optimizer_update_8bit_blockwise_impl(
qmap2: Optional[torch.Tensor], qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor, absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor], absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0, weight_decay: float,
gnorm_scale: float = 1.0, gnorm_scale: float,
skip_zeros=False, skip_zeros=False,
) -> None: ) -> None:
# torch._check( # torch._check(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment