Commit e33ba1c0 authored by Egor Krivov's avatar Egor Krivov
Browse files

Added mutated args to the schema

parent 24d9139e
...@@ -352,7 +352,7 @@ if ipex_cpu or ipex_xpu: ...@@ -352,7 +352,7 @@ if ipex_cpu or ipex_xpu:
torch.library.define( torch.library.define(
"bitsandbytes::optimizer_update_32bit", "bitsandbytes::optimizer_update_32bit",
"(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, Tensor! unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()", "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()",
) )
...@@ -395,7 +395,7 @@ def _( ...@@ -395,7 +395,7 @@ def _(
torch.library.define( torch.library.define(
"bitsandbytes::optimizer_update_8bit_blockwise", "bitsandbytes::optimizer_update_8bit_blockwise",
"(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor qmap1, Tensor! qmap2, Tensor absmax1, Tensor! absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()", "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment