"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "5b015890be1880f910219c9ea966054b8f98affc"
Commit b43edf56 authored by Egor Krivov

Add interface for 8bit optimizer

parent adc7fda7
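This change moves the blockwise 8-bit optimizer update behind a torch custom op: the op schema and a fake (meta) kernel are declared once, the CUDA backend registers its ctypes-based implementation for that op, and bitsandbytes.functional.optimizer_update_8bit_blockwise now dispatches through torch.ops. Below is a minimal sketch of calling the new interface directly; the shapes, blocksize, and hyperparameter values are illustrative assumptions, not values taken from this commit.

import torch
import bitsandbytes  # assumption: importing bitsandbytes registers the bitsandbytes::* ops

n, blocksize = 4096, 256  # illustrative sizes; blocksize=256 is an assumption
device = "cuda"
p = torch.randn(n, device=device)                            # parameters
g = torch.randn(n, device=device)                            # gradients
state1 = torch.zeros(n, dtype=torch.uint8, device=device)    # quantized first moment
state2 = torch.zeros(n, dtype=torch.uint8, device=device)    # quantized second moment
qmap1 = torch.linspace(-1.0, 1.0, 256, device=device)        # float32 quantization maps
qmap2 = torch.linspace(0.0, 1.0, 256, device=device)
absmax1 = torch.ones(n // blocksize, device=device)          # per-block scaling factors
absmax2 = torch.ones(n // blocksize, device=device)

torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
    "adam", g, p, state1, state2,
    0.9, 0.999, 0.0, 0.0,   # beta1, beta2, beta3, alpha
    1e-8, 1, 1e-3,          # eps, step, lr
    qmap1, qmap2, absmax1, absmax2,
    0.0, 1.0, False,        # weight_decay, gnorm_scale, skip_zeros
)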
@@ -348,3 +348,64 @@ if ipex_cpu or ipex_xpu:
     ) -> torch.Tensor:
         torch._check_is_size(blocksize)
         return torch.empty(shape, dtype=dtype, device=A.device)
+torch.library.define(
+    "bitsandbytes::optimizer_update_8bit_blockwise",
+    "(str optimizer_name, Tensor g, Tensor p, Tensor state1, Tensor! state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor qmap1, Tensor! qmap2, Tensor absmax1, Tensor! absmax2, float weight_decay, float gnorm_scale, bool skip_zeros) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_8bit_blockwise")
+def _(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    skip_zeros=False,
+) -> None:
+    torch._check(
+        g.numel() == p.numel(),
+        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    )
+    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+    torch._check(
+        g.dtype in compute_dtypes,
+        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    )
+    torch._check(
+        g.dtype == p.dtype,
+        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    )
+    torch._check(
+        state1.dtype == torch.uint8,
+        lambda: f"state1 must be uint8, got {state1.dtype}",
+    )
+    torch._check(
+        qmap1.dtype == absmax1.dtype == torch.float32,
+        lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    )
+    if state2 is not None:
+        torch._check(
+            state2.dtype == torch.uint8,
+            lambda: f"state2 must be uint8, got {state2.dtype}",
+        )
+        torch._check(
+            qmap2.dtype == absmax2.dtype == torch.float32,
+            lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+        )
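The fake registration above performs only shape and dtype validation and returns nothing, which is what lets FakeTensor/meta tracing (for example under torch.compile) pass through the op without touching real device memory. A toy sketch of the same define / fake / per-device-impl pattern, using a hypothetical op name that is not part of bitsandbytes:

import torch

# Hypothetical demo op mirroring the pattern used above.
torch.library.define("mylib::scale_inplace", "(Tensor(a!) x, float s) -> ()")

@torch.library.register_fake("mylib::scale_inplace")
def _(x, s):
    # Validation only; nothing is computed during fake/meta tracing.
    torch._check(x.dtype in (torch.float16, torch.bfloat16, torch.float32))

@torch.library.impl("mylib::scale_inplace", "cpu")
def _(x, s):
    x.mul_(s)  # real kernel used for actual CPU tensors

t = torch.ones(4)
torch.ops.mylib.scale_inplace(t, 2.0)  # t is now all 2.0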
@@ -538,3 +538,133 @@ def _gemv_4bit_impl(
             ct.c_int32(blocksize),
             stream,
         )
+str2optimizer8bit_blockwise = {
+    "adam": (
+        lib.cadam_8bit_blockwise_grad_fp32,
+        lib.cadam_8bit_blockwise_grad_fp16,
+        lib.cadam_8bit_blockwise_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum_8bit_blockwise_grad_fp32,
+        lib.cmomentum_8bit_blockwise_grad_fp16,
+        lib.cmomentum_8bit_blockwise_grad_bf16,
+    ),
+    "rmsprop": (
+        lib.crmsprop_8bit_blockwise_grad_fp32,
+        lib.crmsprop_8bit_blockwise_grad_fp16,
+        lib.crmsprop_8bit_blockwise_grad_bf16,
+    ),
+    "lion": (
+        lib.clion_8bit_blockwise_grad_fp32,
+        lib.clion_8bit_blockwise_grad_fp16,
+        lib.clion_8bit_blockwise_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad_8bit_blockwise_grad_fp32,
+        lib.cadagrad_8bit_blockwise_grad_fp16,
+        lib.cadagrad_8bit_blockwise_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix_8bit_blockwise_grad_fp32,
+        lib.cademamix_8bit_blockwise_grad_fp16,
+        lib.cademamix_8bit_blockwise_grad_bf16,
+    ),
+}
+
+
+def _optimizer_update_8bit_blockwise_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float = 0.0,
+    gnorm_scale: float = 1.0,
+    skip_zeros=False,
+) -> None:
+    # torch._check(
+    #     g.numel() == p.numel(),
+    #     lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    # )
+    # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+    # torch._check(
+    #     g.dtype in compute_dtypes,
+    #     lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    # )
+    # torch._check(
+    #     g.dtype == p.dtype,
+    #     lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    # )
+    # torch._check(
+    #     state1.dtype == torch.uint8,
+    #     lambda: f"state1 must be uint8, got {state1.dtype}",
+    # )
+    # torch._check(
+    #     qmap1.dtype == absmax1.dtype == torch.float32,
+    #     lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    # )
+    # if state2 is not None:
+    #     torch._check(
+    #         state2.dtype == torch.uint8,
+    #         lambda: f"state2 must be uint8, got {state2.dtype}",
+    #     )
+    #     torch._check(
+    #         qmap2.dtype == absmax2.dtype == torch.float32,
+    #         lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+    #     )
+    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
+    if optimizer_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+
+    if g.dtype == torch.float32:
+        optimizer_fn = optimizer_fns[0]
+    elif g.dtype == torch.float16:
+        optimizer_fn = optimizer_fns[1]
+    elif g.dtype == torch.bfloat16:
+        optimizer_fn = optimizer_fns[2]
+    else:
+        raise ValueError(
+            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
+        )
+
+    with _cuda_device_of(g):
+        optimizer_fn(
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            get_ptr(qmap1),
+            get_ptr(qmap2),
+            get_ptr(absmax1),
+            get_ptr(absmax2),
+            ct.c_float(weight_decay),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
@@ -82,39 +82,6 @@ str2optimizer8bit = {
     ),
 }
-str2optimizer8bit_blockwise = {
-    "adam": (
-        lib.cadam_8bit_blockwise_grad_fp32,
-        lib.cadam_8bit_blockwise_grad_fp16,
-        lib.cadam_8bit_blockwise_grad_bf16,
-    ),
-    "momentum": (
-        lib.cmomentum_8bit_blockwise_grad_fp32,
-        lib.cmomentum_8bit_blockwise_grad_fp16,
-        lib.cmomentum_8bit_blockwise_grad_bf16,
-    ),
-    "rmsprop": (
-        lib.crmsprop_8bit_blockwise_grad_fp32,
-        lib.crmsprop_8bit_blockwise_grad_fp16,
-        lib.crmsprop_8bit_blockwise_grad_bf16,
-    ),
-    "lion": (
-        lib.clion_8bit_blockwise_grad_fp32,
-        lib.clion_8bit_blockwise_grad_fp16,
-        lib.clion_8bit_blockwise_grad_bf16,
-    ),
-    "adagrad": (
-        lib.cadagrad_8bit_blockwise_grad_fp32,
-        lib.cadagrad_8bit_blockwise_grad_fp16,
-        lib.cadagrad_8bit_blockwise_grad_bf16,
-    ),
-    "ademamix": (
-        lib.cademamix_8bit_blockwise_grad_fp32,
-        lib.cademamix_8bit_blockwise_grad_fp16,
-        lib.cademamix_8bit_blockwise_grad_bf16,
-    ),
-}
 class GlobalPageManager:
     _instance = None
@@ -422,8 +389,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
     for t in tensors:
         # NULL pointers and paged tensors are OK.
         if t is not None and not getattr(t, "is_paged", False):
-            on_gpu &= t.is_cuda
-            gpu_ids.add(t.device.index)
+            on_gpu &= t.device.type != "cpu"
+            gpu_ids.add((t.device.type, t.device.index))
     if not on_gpu:
         raise RuntimeError(
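With the relaxed check, is_on_gpu accepts any non-CPU accelerator tensor (XPU as well as CUDA) and groups device indices per device type when verifying that all tensors live on the same device. A small usage sketch, assuming a CUDA device is available:

import torch
from bitsandbytes.functional import is_on_gpu

a = torch.zeros(4, device="cuda:0")
b = torch.zeros(4, device="cuda:0")
is_on_gpu([a, b, None])          # passes: None entries are treated as NULL pointers
is_on_gpu([a, torch.zeros(4)])   # raises RuntimeError: the CPU tensor fails the check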
@@ -1449,45 +1416,29 @@ def optimizer_update_8bit_blockwise(
 ) -> None:
     optim_func = None
-    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
-    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
-    elif (
-        g.dtype == torch.bfloat16
-        and state1.dtype == torch.uint8
-        and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
-    ):
-        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
-    else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
-        )
     is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
-    with _cuda_device_of(g):
-        optim_func(
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(beta3),
-            ct.c_float(alpha),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(absmax1),
-            get_ptr(absmax2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+    torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
+        optimizer_name,
+        g,
+        p,
+        state1,
+        state2,
+        beta1,
+        beta2,
+        beta3,
+        alpha,
+        eps,
+        step,
+        lr,
+        qmap1,
+        qmap2,
+        absmax1,
+        absmax2,
+        weight_decay,
+        gnorm_scale,
+        skip_zeros,
+    )
 @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
@@ -10,6 +10,7 @@ from typing import Optional
 import torch
 import bitsandbytes.functional as F
+from bitsandbytes.utils import sync_gpu
 class MockArgs:
@@ -289,11 +290,11 @@ class Optimizer8bit(torch.optim.Optimizer):
                 self.prefetch_state(p)
                 self.update_step(group, p, gindex, pindex)
-                torch.cuda.synchronize()
+                sync_gpu(p)
         if self.is_paged:
             # all paged operations are asynchronous, we need
             # to sync to make sure all tensors are in the right state
-            torch.cuda.synchronize()
+            sync_gpu(loss)
         return loss
@@ -209,3 +209,10 @@ def unpack_tensor_to_dict(tensor_data):
 LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
 INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
+
+
+def sync_gpu(t: torch.Tensor):
+    if t.device.type == "cuda":
+        torch.cuda.synchronize()
+    elif t.device.type == "xpu":
+        torch.xpu.synchronize()
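sync_gpu replaces the hard-coded torch.cuda.synchronize() calls in the optimizer step, waiting on whichever accelerator the given tensor lives on and doing nothing for other device types. A small usage sketch, assuming a CUDA device is available:

import torch
from bitsandbytes.utils import sync_gpu

t_gpu = torch.empty(1, device="cuda")
t_cpu = torch.empty(1)
sync_gpu(t_gpu)  # calls torch.cuda.synchronize()
sync_gpu(t_cpu)  # device type is "cpu": falls through without synchronizing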