Commit 4075a643 authored by Egor Krivov's avatar Egor Krivov
Browse files

Update to kernel registration

parent 223fea51
......@@ -610,7 +610,7 @@ str2optimizer8bit_blockwise = {
}
def optimizer_update_32bit(
def _optimizer_update_32bit_impl(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
......@@ -763,3 +763,4 @@ def _optimizer_update_8bit_blockwise_impl(
register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment