Commit 3b89a05e authored by Egor Krivov

Add 32bit optimizer interface

parent abf4a1e3
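Before the diff itself, a minimal sketch of how the newly added op could be called once it is registered. This is an illustration only: the import side effect, the CUDA device, and the Adam hyperparameter values are assumptions, not part of this commit.

import torch
import bitsandbytes  # assumption: importing registers the bitsandbytes::* custom ops

# Placeholder tensors for a single Adam update; shapes and values are illustrative.
p = torch.randn(4096, device="cuda", dtype=torch.float32)   # parameter
g = torch.randn_like(p)                                      # gradient
state1 = torch.zeros_like(p)                                 # first moment
state2 = torch.zeros_like(p)                                 # second moment
unorm = torch.zeros(1, device="cuda", dtype=torch.float32)   # update-norm scratch

torch.ops.bitsandbytes.optimizer_update_32bit(
    "adam", g, p, state1, state2, unorm,
    0.0,     # max_unorm (0 disables update clipping)
    0.0,     # param_norm
    0.9,     # beta1
    0.999,   # beta2
    0.0,     # beta3 (unused by adam)
    0.0,     # alpha (unused by adam)
    1e-8,    # eps
    0.0,     # weight_decay
    1,       # step
    1e-3,    # lr
    1.0,     # gnorm_scale
    False,   # skip_zeros
)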
@@ -350,6 +350,49 @@ if ipex_cpu or ipex_xpu:
        return torch.empty(shape, dtype=dtype, device=A.device)
torch.library.define(
    "bitsandbytes::optimizer_update_32bit",
    "(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, Tensor! unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros) -> ()",
)


@register_fake("bitsandbytes::optimizer_update_32bit")
def _(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    torch._check(
        g.numel() == p.numel(),
        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
    )
    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    torch._check(
        g.dtype in compute_dtypes,
        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
    )
    torch._check(
        g.dtype == p.dtype,
        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
    )
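The block above follows the standard torch.library pattern: declare the op schema, then register a fake (meta) kernel that only validates shapes and dtypes so the op can be traced or compiled without the native backend. For reference, the same pattern with a self-contained toy op (names here are hypothetical and not part of bitsandbytes; assumes PyTorch 2.4+ for torch.library.register_fake):

import torch
from torch.library import register_fake  # assumption: PyTorch >= 2.4

# Toy op: schema first, then a fake kernel that performs only cheap validation.
torch.library.define("toylib::scale_", "(Tensor(a!) x, float s) -> ()")

@register_fake("toylib::scale_")
def _(x: torch.Tensor, s: float) -> None:
    torch._check(
        x.dtype in (torch.float16, torch.bfloat16, torch.float32),
        lambda: f"x must be bfloat16, float16, or float32, got {x.dtype}",
    )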
torch.library.define(
    "bitsandbytes::optimizer_update_8bit_blockwise",
    "(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor qmap1, Tensor! qmap2, Tensor absmax1, Tensor! absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
...
@@ -540,6 +540,42 @@ def _gemv_4bit_impl(
    )
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
    "adam": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "momentum": (
        lib.cmomentum32bit_grad_32,
        lib.cmomentum32bit_grad_16,
    ),
    "rmsprop": (
        lib.crmsprop32bit_grad_32,
        lib.crmsprop32bit_grad_16,
    ),
    "lion": (
        lib.clion32bit_grad_fp32,
        lib.clion32bit_grad_fp16,
        lib.clion32bit_grad_bf16,
    ),
    "adagrad": (
        lib.cadagrad32bit_grad_32,
        lib.cadagrad32bit_grad_16,
    ),
    "lamb": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "ademamix": (
        lib.cademamix32bit_grad_fp32,
        lib.cademamix32bit_grad_fp16,
        lib.cademamix32bit_grad_bf16,
    ),
}
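The tuples above hold kernel pointers ordered as (fp32, fp16[, bf16]), which is what the dtype dispatch in optimizer_update_32bit below relies on. A purely illustrative way to read that structure (not part of the commit):

# Illustrative only: list the gradient dtypes each 32-bit optimizer exposes,
# inferred from the tuple length used by the dispatch below.
for name, fns in str2optimizer32bit.items():
    dtypes = ["float32", "float16"] + (["bfloat16"] if len(fns) == 3 else [])
    print(f"{name}: {', '.join(dtypes)}")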
str2optimizer8bit_blockwise = {
    "adam": (
        lib.cadam_8bit_blockwise_grad_fp32,
@@ -574,6 +610,65 @@ str2optimizer8bit_blockwise = {
}
def optimizer_update_32bit(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    optim_fns = str2optimizer32bit.get(optimizer_name, None)
    if optim_fns is None:
        raise ValueError(
            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer32bit.keys())}"
        )
    if g.dtype == torch.float32:
        optim_func = optim_fns[0]
    elif g.dtype == torch.float16:
        optim_func = optim_fns[1]
    elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
        optim_func = optim_fns[2]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
        )

    with _cuda_device_of(g):
        optim_func(
            get_ptr(g),
            get_ptr(p),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(beta3),
            ct.c_float(alpha),
            ct.c_float(eps),
            ct.c_float(weight_decay),
            ct.c_int32(step),
            ct.c_float(lr),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )
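The function above is the ctypes-backed backend implementation; the diff does not show how it is bound to the "bitsandbytes::optimizer_update_32bit" schema, so the following is only a sketch of one plausible wiring via torch.library.register_kernel (PyTorch 2.4+). The decorator and wrapper name are assumptions, not the codebase's actual registration:

import torch

# Hypothetical binding of the CUDA backend function to the op schema declared earlier.
@torch.library.register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")
def _optimizer_update_32bit_cuda(
    optimizer_name, g, p, state1, state2, unorm_vec,
    max_unorm, param_norm, beta1, beta2, beta3, alpha,
    eps, weight_decay, step, lr, gnorm_scale, skip_zeros=False,
):
    optimizer_update_32bit(
        optimizer_name, g, p, state1, state2, unorm_vec,
        max_unorm, param_norm, beta1, beta2, beta3, alpha,
        eps, weight_decay, step, lr, gnorm_scale, skip_zeros,
    )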
def _optimizer_update_8bit_blockwise_impl(
    optimizer_name: str,
    g: torch.Tensor,
...
@@ -20,41 +20,6 @@ from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib
name2qmap = {}
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
    "adam": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "momentum": (
        lib.cmomentum32bit_grad_32,
        lib.cmomentum32bit_grad_16,
    ),
    "rmsprop": (
        lib.crmsprop32bit_grad_32,
        lib.crmsprop32bit_grad_16,
    ),
    "lion": (
        lib.clion32bit_grad_fp32,
        lib.clion32bit_grad_fp16,
        lib.clion32bit_grad_bf16,
    ),
    "adagrad": (
        lib.cadagrad32bit_grad_32,
        lib.cadagrad32bit_grad_16,
    ),
    "lamb": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "ademamix": (
        lib.cademamix32bit_grad_fp32,
        lib.cademamix32bit_grad_fp16,
        lib.cademamix32bit_grad_bf16,
    ),
}
str2optimizer8bit = {
    "adam": (
        lib.cadam_static_8bit_grad_32,
@@ -1219,41 +1184,27 @@ def optimizer_update_32bit(
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())
    optim_func = None
    if g.dtype == torch.float32:
        optim_func = str2optimizer32bit[optimizer_name][0]
    elif g.dtype == torch.float16:
        optim_func = str2optimizer32bit[optimizer_name][1]
    elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3:
        optim_func = str2optimizer32bit[optimizer_name][2]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
        )
    is_on_gpu([g, p, state1, state2, unorm_vec])

    with _cuda_device_of(g):
        optim_func(
            get_ptr(g),
            get_ptr(p),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(beta3),
            ct.c_float(alpha),
            ct.c_float(eps),
            ct.c_float(weight_decay),
            ct.c_int32(step),
            ct.c_float(lr),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )

    torch.ops.bitsandbytes.optimizer_update_32bit(
        optimizer_name,
        g,
        p,
        state1,
        state2,
        unorm_vec,
        max_unorm,
        param_norm,
        beta1,
        beta2,
        beta3,
        alpha,
        eps,
        weight_decay,
        step,
        lr,
        gnorm_scale,
        skip_zeros,
    )

@deprecated(
...