Unverified Commit d9645465 authored by Matthew Douglas, committed by GitHub

Add AdEMAMix optimizer (#1360)

* Add AdEMAMix optimizer

* Add PagedAdEMAMix32bit, AdEMAMix32bit

* Add PagedAdEMAMix32bit, AdEMAMix32bit

* AdEMAMix: add support for alpha/beta3 scheduling

* Update paged AdEMAMix
parent 8fc78924
......@@ -53,6 +53,11 @@ if lib and lib.compiled_with_cuda:
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
),
"ademamix": (
lib.cademamix32bit_grad_fp32,
lib.cademamix32bit_grad_fp16,
lib.cademamix32bit_grad_bf16,
),
}
str2optimizer8bit = {
......@@ -105,6 +110,11 @@ if lib and lib.compiled_with_cuda:
lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp16,
),
"ademamix": (
lib.cademamix_8bit_blockwise_grad_fp32,
lib.cademamix_8bit_blockwise_grad_fp16,
lib.cademamix_8bit_blockwise_grad_bf16,
),
}
......@@ -1550,6 +1560,8 @@ def optimizer_update_32bit(
lr: float,
state2: Optional[torch.Tensor] = None,
beta2: float = 0.0,
beta3: float = 0.0,
alpha: float = 0.0,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Optional[torch.Tensor] = None,
......@@ -1585,6 +1597,10 @@ def optimizer_update_32bit(
Optimizer state 2.
beta2 : float
Optimizer beta2.
beta3 : float
Optimizer beta3. Used by AdEMAMix as the decay rate of the slow EMA.
alpha : float
Optimizer alpha. Used by AdEMAMix to weight the slow EMA in the update.
gnorm_scale : float
The factor to rescale the gradient to the max clip value.
unorm_vec : torch.Tensor
......@@ -1623,6 +1639,8 @@ def optimizer_update_32bit(
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(beta3),
ct.c_float(alpha),
ct.c_float(eps),
ct.c_float(weight_decay),
ct.c_int32(step),
......@@ -1775,6 +1793,8 @@ def optimizer_update_8bit_blockwise(
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
step: int,
lr: float,
......@@ -1815,6 +1835,8 @@ def optimizer_update_8bit_blockwise(
get_ptr(state2),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(beta3),
ct.c_float(alpha),
ct.c_float(eps),
ct.c_int32(step),
ct.c_float(lr),
......
......@@ -13,6 +13,7 @@ from .adamw import (
PagedAdamW8bit,
PagedAdamW32bit,
)
from .ademamix import AdEMAMix, AdEMAMix8bit, AdEMAMix32bit, PagedAdEMAMix, PagedAdEMAMix8bit, PagedAdEMAMix32bit
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
......
import math
from typing import Iterable, Literal, Optional, Tuple
import torch
import bitsandbytes.functional as F
from bitsandbytes.optim.optimizer import Optimizer2State
class _ReferenceAdEMAMix(torch.optim.Optimizer):
"""
Reference: https://hf.co/papers/2409.03137
"""
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
eps: float = 1e-8,
weight_decay: float = 1e-2, # default 0.0 or 1e-2?
t_beta3: Optional[int] = None,
t_alpha: Optional[int] = None,
):
defaults = dict(
lr=lr, betas=betas, alpha=alpha, eps=eps, weight_decay=weight_decay, t_beta3=t_beta3, t_alpha=t_alpha
)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
if "step" in group:
group["step"] += 1
else:
group["step"] = 1
lr = group["lr"]
eps = group["eps"]
beta1, beta2, beta3 = group["betas"]
alpha = group["alpha"]
t_alpha = group["t_alpha"]
t_beta3 = group["t_beta3"]
weight_decay = group["weight_decay"]
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
# State initialization
if len(state) == 0:
# For parity with the bnb implementation, we combine the fast
# and slow EMA stats into one stacked tensor.
state["m1_m2"] = p.new_zeros((2, *p.size()))
state["nu"] = torch.zeros_like(p) # second moment estimate
m1, m2, nu = state["m1_m2"][0], state["m1_m2"][1], state["nu"]
bias_correction1 = 1 - beta1 ** group["step"]
bias_correction2 = 1 - beta2 ** group["step"]
# Apply scheduler for alpha
if t_alpha is not None:
alpha = min(group["step"] * alpha / t_alpha, alpha)
# Apply scheduler for beta3
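# beta3_t ramps from beta1 up to beta3, interpolating linearly in the
# EMA's effective time scale (1/ln(beta)), following the AdEMAMix paper.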
if t_beta3 is not None:
ln_beta1 = math.log(beta1)
ln_beta3 = math.log(beta3)
step_scale = group["step"] / t_beta3
beta3 = min(
math.exp((ln_beta1 * ln_beta3) / (((1 - step_scale) * ln_beta3) + (step_scale * ln_beta1))),
beta3,
)
# Update the EMAs
m1.mul_(beta1).add_(grad, alpha=1 - beta1)
m2.mul_(beta3).add_(grad, alpha=1 - beta3)
nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Compute step
denom = (nu.sqrt() / (bias_correction2**0.5)).add(eps)
update = (m1.div(bias_correction1) + alpha * m2) / denom
# Add weight decay
update.add_(p, alpha=weight_decay)
# Apply update scaled by learning rate
p.add_(-lr * update)
return loss
class AdEMAMix(Optimizer2State):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
optim_bits: Literal[8, 32] = 32,
min_8bit_size: int = 4096,
is_paged: bool = False,
):
super().__init__(
"ademamix",
params=params,
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
optim_bits=optim_bits,
args=None,
min_8bit_size=min_8bit_size,
percentile_clipping=100,
block_wise=True,
is_paged=is_paged,
alpha=alpha,
t_alpha=t_alpha,
t_beta3=t_beta3,
)
@torch.no_grad()
def init_state(self, group, p, gindex, pindex):
# In our AdEMAMix implementation, `state1` holds both the fast and
# slow EMAs, so we override the base `Optimizer2State` method to
# allocate a buffer twice as large.
# Note: we do not support block_wise=False, percentile clipping,
# or max_unorm.
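# For example, with optim_bits=32 and p of shape (4096, 4096), state1
# is allocated as a (2, 4096, 4096) float32 buffer (fast and slow EMAs).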
config = self.get_config(gindex, pindex, group)
if config["optim_bits"] == 32:
dtype = torch.float32
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
state = self.state[p]
state["step"] = 0
if dtype == torch.uint8:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = state["qmap1"] = self.name2qmap["dynamic"].to(p.device)
self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device)
n = p.numel()
blocks = (n // 2048) + bool(n % 2048)
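# One scale per 2048-element block; absmax1 is stacked to cover both
# EMAs packed into state1, while absmax2 covers the second moment.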
state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
state["state1"] = self._get_state_double_buffer(p, dtype=dtype)
state["state2"] = self.get_state_buffer(p, dtype=dtype)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
if config["t_alpha"] is None and config["t_beta3"] is None:
# No alpha/beta3 scheduling, so we can defer to the base implementation.
super().update_step(group, p, gindex, pindex)
return
# Ensure contiguous memory layout
p.data = p.data.contiguous()
p.grad = p.grad.contiguous()
state = self.state[p]
grad = p.grad
state["step"] += 1
step = state["step"]
beta1, beta2, beta3 = config["betas"]
alpha = config["alpha"]
t_alpha = config["t_alpha"]
t_beta3 = config["t_beta3"]
# Apply scheduler for alpha
if t_alpha is not None:
alpha_t = min(step * alpha / t_alpha, alpha)
else:
alpha_t = alpha
# Apply scheduler for beta3
if t_beta3 is not None:
ln_beta1 = math.log(beta1)
ln_beta3 = math.log(beta3)
step_scale = step / t_beta3
beta3_t = min(
math.exp((ln_beta1 * ln_beta3) / (((1 - step_scale) * ln_beta3) + (step_scale * ln_beta1))), beta3
)
else:
beta3_t = beta3
# Apply updates
if state["state1"].dtype == torch.float32:
F.optimizer_update_32bit(
self.optimizer_name,
grad,
p,
state["state1"],
beta1,
config["eps"],
step,
config["lr"],
state["state2"],
beta2,
beta3_t,
alpha_t,
config["weight_decay"],
gnorm_scale=1.0,
unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
)
elif state["state1"].dtype == torch.uint8:
F.optimizer_update_8bit_blockwise(
self.optimizer_name,
grad,
p,
state["state1"],
state["state2"],
config["betas"][0],
config["betas"][1],
beta3_t,
alpha_t,
config["eps"],
step,
config["lr"],
state["qmap1"],
state["qmap2"],
state["absmax1"],
state["absmax2"],
config["weight_decay"],
gnorm_scale=1.0,
skip_zeros=config["skip_zeros"],
)
def _get_state_double_buffer(self, p, dtype=torch.float32):
if not self.is_paged or p.numel() < 1e5:
return torch.zeros((2, *p.size()), dtype=dtype, device=p.device)
else:
buff = F.get_paged(*(2, *p.size()), dtype=dtype, device=p.device)
F.fill(buff, 0)
self.page_mng.paged_tensors.append(buff)
return buff
class AdEMAMix8bit(AdEMAMix):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
min_8bit_size: int = 4096,
is_paged: bool = False,
):
super().__init__(
params,
lr=lr,
betas=betas,
alpha=alpha,
t_alpha=t_alpha,
t_beta3=t_beta3,
eps=eps,
weight_decay=weight_decay,
optim_bits=8,
min_8bit_size=min_8bit_size,
is_paged=is_paged,
)
class PagedAdEMAMix8bit(AdEMAMix8bit):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
min_8bit_size: int = 4096,
):
super().__init__(
params,
lr=lr,
betas=betas,
alpha=alpha,
t_alpha=t_alpha,
t_beta3=t_beta3,
eps=eps,
weight_decay=weight_decay,
min_8bit_size=min_8bit_size,
is_paged=True,
)
class PagedAdEMAMix(AdEMAMix):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
optim_bits: Literal[8, 32] = 32,
min_8bit_size: int = 4096,
):
super().__init__(
params,
lr=lr,
betas=betas,
alpha=alpha,
t_alpha=t_alpha,
t_beta3=t_beta3,
eps=eps,
weight_decay=weight_decay,
optim_bits=optim_bits,
min_8bit_size=min_8bit_size,
is_paged=True,
)
class AdEMAMix32bit(Optimizer2State):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
min_8bit_size: int = 4096,
is_paged: bool = False,
):
super().__init__(
"ademamix",
params=params,
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
optim_bits=32,
args=None,
min_8bit_size=min_8bit_size,
percentile_clipping=100,
block_wise=True,
is_paged=is_paged,
alpha=alpha,
t_alpha=t_alpha,
t_beta3=t_beta3,
)
class PagedAdEMAMix32bit(AdEMAMix32bit):
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
eps: float = 1e-8,
weight_decay: float = 1e-2,
min_8bit_size: int = 4096,
):
super().__init__(
params,
lr=lr,
betas=betas,
alpha=alpha,
t_alpha=t_alpha,
t_beta3=t_beta3,
eps=eps,
weight_decay=weight_decay,
min_8bit_size=min_8bit_size,
is_paged=True,
)
......@@ -5,6 +5,7 @@
from collections import abc as container_abcs, defaultdict
from copy import deepcopy
from itertools import chain
from typing import Optional
import torch
......@@ -172,7 +173,7 @@ class Optimizer8bit(torch.optim.Optimizer):
raise ValueError("loaded state dict has a different number of parameter groups")
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True)):
raise ValueError(
"loaded state dict contains a parameter group that doesn't match the size of optimizer's group",
)
......@@ -183,6 +184,7 @@ class Optimizer8bit(torch.optim.Optimizer):
for old_id, p in zip(
chain.from_iterable(g["params"] for g in saved_groups),
chain.from_iterable(g["params"] for g in groups),
strict=True,
)
}
......@@ -224,7 +226,7 @@ class Optimizer8bit(torch.optim.Optimizer):
new_group["params"] = group["params"]
return new_group
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True)]
self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
......@@ -302,6 +304,9 @@ class Optimizer8bit(torch.optim.Optimizer):
config["eps"] = group["eps"]
config["weight_decay"] = group["weight_decay"]
config["lr"] = group["lr"]
config["alpha"] = group.get("alpha")
config["t_alpha"] = group.get("t_alpha")
config["t_beta3"] = group.get("t_beta3")
config["optim_bits"] = self.args.optim_bits
config["min_8bit_size"] = self.args.min_8bit_size
config["percentile_clipping"] = self.args.percentile_clipping
......@@ -357,6 +362,9 @@ class Optimizer2State(Optimizer8bit):
max_unorm=0.0,
skip_zeros=False,
is_paged=False,
alpha=0.0,
t_alpha: Optional[int] = None,
t_beta3: Optional[int] = None,
):
"""
Base 2-state update optimizer class.
......@@ -390,6 +398,13 @@ class Optimizer2State(Optimizer8bit):
Whether to skip zero values for sparse gradients and models to ensure correct updates.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
alpha (`float`, defaults to 0.0):
The alpha value for the AdEMAMix optimizer.
t_alpha (`Optional[int]`, defaults to `None`):
Number of iterations for alpha scheduling with AdEMAMix.
t_beta3 (`Optional[int]`, defaults to `None`):
Number of iterations for beta3 scheduling with AdEMAMix.
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
......@@ -404,7 +419,11 @@ class Optimizer2State(Optimizer8bit):
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, alpha=alpha, t_alpha=t_alpha, t_beta3=t_beta3
)
super().__init__(params, defaults, optim_bits, is_paged)
if args is None:
......@@ -511,6 +530,8 @@ class Optimizer2State(Optimizer8bit):
config["lr"],
state["state2"],
config["betas"][1],
config["betas"][2] if len(config["betas"]) >= 3 else 0.0,
config["alpha"],
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
......@@ -554,6 +575,8 @@ class Optimizer2State(Optimizer8bit):
state["state2"],
config["betas"][0],
config["betas"][1],
config["betas"][2] if len(config["betas"]) >= 3 else 0.0,
config["alpha"],
config["eps"],
step,
config["lr"],
......@@ -726,6 +749,8 @@ class Optimizer1State(Optimizer8bit):
config["lr"],
None,
config["betas"][1],
0.0,
0.0,
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
......@@ -767,6 +792,8 @@ class Optimizer1State(Optimizer8bit):
None,
config["betas"][0],
config["betas"][1],
0.0,
0.0,
config["eps"],
step,
config["lr"],
......
......@@ -874,7 +874,7 @@ template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
......@@ -885,9 +885,16 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
T g_vals[NUM_PER_THREAD];
T p_vals[NUM_PER_THREAD];
float s1_vals[NUM_PER_THREAD];
float s2_vals[NUM_PER_THREAD];
// AdEMAMix has an additional state buffer, which we pack into state1,
// so we need thread-local storage for those values here.
// TODO: Mark with [[maybe_unused]] once the minimum supported compiler allows it.
float s3_vals[NUM_PER_THREAD];
const float correction1 = 1.0f - powf(beta1, step);
const float correction2 = sqrtf(1.0f - powf(beta2, step));
const float step_size = -lr*correction2/correction1;
......@@ -926,6 +933,13 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
__syncthreads();
Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
// Load the additional state1 data for AdEMAMix.
// TODO: Make this `if constexpr` once the minimum supported compiler allows it.
if (OPTIMIZER == ADEMAMIX) {
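// state1 holds 2*n values: the fast EMA occupies [0, n) and the slow EMA [n, 2n).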
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items);
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
......@@ -935,7 +949,28 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
{
switch(OPTIMIZER)
{
case ADEMAMIX:
// m1 update: m1 = beta1 * m1 + (1-beta1) * g
s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]);
// m2 update: m2 = m2 * beta3 + (1-beta3) * g
s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]);
// nu update: nu = beta2 * nu + (1-beta2) * g^2
s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]);
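// Mix the bias-corrected fast EMA with the slow EMA (weighted by alpha);
// note the slow EMA enters without bias correction.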
p_vals[j] = (float)p_vals[j] - lr * (
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
(sqrtf(s2_vals[j]) / correction2) + eps
)
);
if (weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
break;
case ADAM:
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
......@@ -955,6 +990,11 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items);
if (OPTIMIZER == ADEMAMIX) {
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items);
}
}
}
......@@ -1644,14 +1684,27 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* absmax1, float* absmax2,
float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n)
{
kOptimizerStatic8bit2StateBlockwise(
T* p,
T* __restrict__ const g,
unsigned char* state1,
unsigned char* state2,
const float beta1,
const float beta2,
const float beta3,
const float alpha,
const float eps,
const int step,
const float lr,
float* __restrict__ const quantiles1,
float* __restrict__ const quantiles2,
float* absmax1,
float* absmax2,
float weight_decay,
const float gnorm_scale,
const bool skip_zeros,
const int n
) {
//const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE;
......@@ -1660,6 +1713,8 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
float g_val = 0.0f;
float s1_vals[N_PER_TH];
float s2_vals[N_PER_TH];
float s3_vals[N_PER_TH];
// 2-5%
const float correction1 = 1.0f - __powf(beta1, step);
const float correction2 = sqrtf(1.0f -__powf(beta2, step));
......@@ -1667,11 +1722,14 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
const int lane_id = threadIdx.x % LANES;
float new_local_abs_max1 = -FLT_MAX;
float new_local_abs_max2 = -FLT_MAX;
float new_local_abs_max3 = -FLT_MAX;
float quadrants1[QUAD];
float quadrants2[QUAD];
unsigned char c1s[N_PER_TH];
unsigned char c2s[N_PER_TH];
unsigned char c3s[N_PER_TH];
T g_vals[N_PER_TH];
T p_vals[N_PER_TH];
typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
......@@ -1684,10 +1742,13 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
__shared__ float smem_quantiles2[LANES][257];
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2;
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce3;
__shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ typename BlockReduce2::TempStorage reduce2;
__shared__ typename BlockReduce3::TempStorage reduce3;
__shared__ float smem_exchange1[1];
__shared__ float smem_exchange2[1];
__shared__ float smem_exchange3[1]; // [[maybe_unused]]
__shared__ union {
typename LoadT::TempStorage loadh;
......@@ -1728,8 +1789,15 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
// AdEMAMix has an additional state packed into state1.
if (OPTIMIZER == ADEMAMIX) {
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128);
}
new_local_abs_max1 = -FLT_MAX;
new_local_abs_max2 = -FLT_MAX;
new_local_abs_max3 = -FLT_MAX;
// update: 2.48/1.57 -> 2.51/1.60
# pragma unroll N_PER_TH
......@@ -1747,15 +1815,29 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
if (OPTIMIZER == ADEMAMIX) {
// The absmax for the third state is appended to absmax1
s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i)/BLOCK_SIZE];
s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val));
}
}
else
{
s1_vals[j] = 0.0f;
s2_vals[j] = 0.0f;
if (OPTIMIZER == ADEMAMIX) {
s3_vals[j] = 0.0f;
}
}
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));
if (OPTIMIZER == ADEMAMIX) {
new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j]));
}
}
......@@ -1763,10 +1845,18 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max());
new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max());
if (OPTIMIZER == ADEMAMIX) {
new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max());
}
if(threadIdx.x == 0)
{
smem_exchange1[0] = new_local_abs_max1;
smem_exchange2[0] = new_local_abs_max2;
if (OPTIMIZER == ADEMAMIX) {
smem_exchange3[0] = new_local_abs_max3;
}
}
__syncthreads();
......@@ -1775,11 +1865,19 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
absmax2[i/BLOCK_SIZE] = new_local_abs_max2;
if (OPTIMIZER == ADEMAMIX) {
absmax1[(n + i)/BLOCK_SIZE] = new_local_abs_max3;
}
}
else
{
new_local_abs_max1 = smem_exchange1[0];
new_local_abs_max2 = smem_exchange2[0];
if (OPTIMIZER == ADEMAMIX) {
new_local_abs_max3 = smem_exchange3[0];
}
}
__syncthreads();
......@@ -1791,8 +1889,17 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
//if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
{
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
if(weight_decay > 0.0f)
if (OPTIMIZER == ADEMAMIX) {
p_vals[j] = T((float)p_vals[j] - lr * (
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
(sqrtf(s2_vals[j]) / correction2) + eps
)
));
} else {
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
}
if(weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
}
}
......@@ -1817,12 +1924,25 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
else
c1s[j] -= 1;
}
if (OPTIMIZER == ADEMAMIX) {
c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j],new_local_abs_max3));
if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) {
c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1;
}
}
}
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
if (OPTIMIZER == ADEMAMIX) {
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items);
}
}
}
......@@ -3740,13 +3860,23 @@ template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8
MAKE_PreconditionOptimizer32bit2State(ADAM, float)
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
......@@ -3904,7 +4034,7 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(fl
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
const float beta1, const float beta2, \
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* absmax1, float* absmax2, \
......@@ -3914,6 +4044,9 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 2048, 8)
#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
......
......@@ -27,7 +27,8 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const float beta1, const float beta2, const float beta3, const float alpha,
const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
......@@ -89,7 +90,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2, const float eps, const int step, const float lr,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
......
......@@ -94,7 +94,7 @@ template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsign
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
int num_blocks = n/4096;
......@@ -102,13 +102,14 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
switch(OPTIMIZER)
{
case ADAM:
case ADEMAMIX:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
......@@ -195,19 +196,40 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
#define BLOCKSIZE_1STATE 2048
#define NUM_1STATE 8
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
T* p,
T* g,
unsigned char* state1,
unsigned char* state2,
float beta1,
float beta2,
float beta3,
float alpha,
float eps,
int step,
float lr,
float* quantiles1,
float* quantiles2,
float* absmax1,
float* absmax2,
float weight_decay,
const float gnorm_scale,
bool skip_zeros,
int n
) {
int num_blocks = 0;
switch(OPTIMIZER)
{
case ADAM:
case ADEMAMIX:
num_blocks = n/BLOCKSIZE_2STATE;
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale,
skip_zeros, n
);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
......@@ -787,7 +809,8 @@ template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char
#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
MAKE_optimizer32bit(ADAM, half)
......@@ -802,6 +825,9 @@ MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
MAKE_optimizer32bit(ADEMAMIX, half)
MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, float)
#define MAKE_optimizerStatic8bit(name, gtype) \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
......@@ -827,7 +853,7 @@ MAKE_optimizerStatic8bit(ADAGRAD, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
MAKE_optimizerStatic8bitBlockwise(half, ADAM);
......@@ -842,6 +868,9 @@ MAKE_optimizerStatic8bitBlockwise(float, LION);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
......
......@@ -72,6 +72,7 @@ typedef enum Optimizer_t
LARS = 3,
ADAGRAD = 4,
LION = 5,
ADEMAMIX = 6
} Optimizer_t;
typedef enum Transform_t
......@@ -149,7 +150,7 @@ template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsign
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
float beta1, float beta2, float eps, float weight_decay,
float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay,
int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
......@@ -162,7 +163,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigne
const float gnorm_scale, int n);
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
bool skip_zeros, int n);
......
......@@ -52,9 +52,10 @@ MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const float weight_decay, \
const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
......@@ -68,6 +69,10 @@ MAKE_FUNC32(lion, LION, half, fp16)
MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
MAKE_FUNC32(ademamix, ADEMAMIX, float, fp32)
MAKE_FUNC32(ademamix, ADEMAMIX, half, fp16)
MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
......@@ -93,9 +98,9 @@ MAKE_FUNC8(lion, LION, half, 16)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
......@@ -109,6 +114,9 @@ MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(lion, LION, half, fp16)
MAKE_BLOCKWISE8(lion, LION, float, fp32)
MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
......@@ -224,9 +232,10 @@ extern "C"
#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
{ name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
{ name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_CFUNC32(adam, float, fp32)
MAKE_CFUNC32(adam, half, fp16)
......@@ -240,6 +249,9 @@ extern "C"
MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16)
MAKE_CFUNC32(ademamix, float, fp32)
MAKE_CFUNC32(ademamix, half, fp16)
MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
#define MAKE_CFUNC8(name, gtype, gbits) \
void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
......@@ -265,9 +277,9 @@ extern "C"
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
{ fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
{ fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
......@@ -281,6 +293,9 @@ extern "C"
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, float, fp32)
MAKE_CBLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
......
......@@ -42,6 +42,8 @@
title: Adam
- local: reference/optim/adamw
title: AdamW
- local: reference/optim/ademamix
title: AdEMAMix
- local: reference/optim/lamb
title: LAMB
- local: reference/optim/lars
......
# AdEMAMix
[AdEMAMix](https://hf.co/papers/2409.03137) is a variant of the [`Adam`] optimizer that adds a second, slower-moving EMA of past gradients, allowing the optimizer to take better advantage of older gradients.
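In addition to Adam's fast gradient EMA (decay rate beta1) and second-moment estimate (decay rate beta2), AdEMAMix tracks a slow EMA with decay rate beta3 (typically 0.9999) and mixes it into the update with weight alpha. Sketching the update rule as implemented by the reference optimizer in this PR, with learning rate η and weight decay λ:

$$
m_1 \leftarrow \beta_1 m_1 + (1-\beta_1)\,g,\qquad
m_2 \leftarrow \beta_3 m_2 + (1-\beta_3)\,g,\qquad
\nu \leftarrow \beta_2 \nu + (1-\beta_2)\,g^2
$$

$$
\theta \leftarrow \theta \;-\; \eta \left( \frac{m_1/(1-\beta_1^t) \;+\; \alpha\, m_2}{\sqrt{\nu/(1-\beta_2^t)} \;+\; \epsilon} \;+\; \lambda\,\theta \right)
$$

Note that the slow EMA enters the update without bias correction.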
bitsandbytes also supports paged optimizers, which use CUDA's unified memory to page optimizer state from the GPU to the CPU when GPU memory is exhausted.
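A minimal usage sketch (the layer shape and schedule lengths below are illustrative, not recommendations):

```py
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()

# 8-bit blockwise AdEMAMix; t_alpha/t_beta3 schedule alpha and beta3
# over the first 10k steps to stabilize early training.
optimizer = bnb.optim.AdEMAMix8bit(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999, 0.9999),
    alpha=5.0,
    t_alpha=10_000,
    t_beta3=10_000,
)

loss = model(torch.randn(8, 4096, device="cuda")).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```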
## AdEMAMix[[api-class]]
[[autodoc]] bitsandbytes.optim.AdEMAMix
- __init__
## AdEMAMix8bit
[[autodoc]] bitsandbytes.optim.AdEMAMix8bit
- __init__
## AdEMAMix32bit
[[autodoc]] bitsandbytes.optim.AdEMAMix32bit
- __init__
## PagedAdEMAMix
[[autodoc]] bitsandbytes.optim.PagedAdEMAMix
- __init__
## PagedAdEMAMix8bit
[[autodoc]] bitsandbytes.optim.PagedAdEMAMix8bit
- __init__
## PagedAdEMAMix32bit
[[autodoc]] bitsandbytes.optim.PagedAdEMAMix32bit
- __init__
......@@ -36,6 +36,8 @@ def rm_path(path):
str2optimizers = {}
## TODO: maybe remove these three.
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
......@@ -43,45 +45,67 @@ str2optimizers["momentum_pytorch"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
bnb.optim.Adam,
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
str2optimizers["paged_adamw8bit_blockwise"] = (
torch.optim.AdamW,
lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)
str2optimizers["ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.AdEMAMix)
str2optimizers["ademamix8bit_blockwise"] = (
bnb.optim.ademamix._ReferenceAdEMAMix,
lambda pxx: bnb.optim.AdEMAMix8bit(pxx),
)
str2optimizers["paged_ademamix"] = (bnb.optim.ademamix._ReferenceAdEMAMix, bnb.optim.PagedAdEMAMix)
str2optimizers["paged_ademamix8bit_blockwise"] = (
bnb.optim.ademamix._ReferenceAdEMAMix,
lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx),
)
str2optimizers["ademamix_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["ademamix8bit_blockwise_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (
torch.optim.AdamW,
lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True),
)
str2optimizers["paged_adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True),
)
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit_blockwise"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
......@@ -118,7 +142,29 @@ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "ab
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
optimizer_names_32bit = ["adam", "momentum", "rmsprop", "paged_adamw", "paged_adam", "lion", "paged_lion"]
str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["paged_ademamix"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
("m1_m2", "state1", "qmap1", "absmax1"),
("nu", "state2", "qmap2", "absmax2"),
]
str2statenames["paged_ademamix8bit_blockwise"] = [
("m1_m2", "state1", "qmap1", "absmax1"),
("nu", "state2", "qmap2", "absmax2"),
]
optimizer_names_32bit = [
"adam",
"paged_adamw",
"paged_adam",
"momentum",
"rmsprop",
"lion",
"paged_lion",
"ademamix",
"ademamix_scheduled",
"paged_ademamix",
]
@pytest.mark.parametrize("optim_name", optimizer_names_32bit, ids=id_formatter("opt"))
......@@ -251,6 +297,8 @@ optimizer_names_8bit = [
"lion8bit_blockwise",
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",
"ademamix8bit_blockwise",
"ademamix8bit_blockwise_scheduled",
]
......@@ -259,7 +307,13 @@ optimizer_names_8bit = [
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name not in ["adam8bit_blockwise", "lion8bit_blockwise"]:
torch.set_printoptions(precision=6)
if gtype == torch.bfloat16 and optim_name not in [
"adam8bit_blockwise",
"lion8bit_blockwise",
"ademamix8bit_blockwise",
]:
pytest.skip()
if dim1 == 1 and dim2 == 1:
return
......@@ -284,7 +338,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
errors = []
relerrors = []
for i in range(100):
for i in range(50):
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
......@@ -293,19 +347,38 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
torch_optimizer.step()
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)
# and AdEMAMix can diverge as well, allow up to 0.05% errors.
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-4))
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
# print(bnb_optimizer.state[p2][max_val], name1)
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
## separately and then stack them. The qmap is shared, but absmax is also stacked.
if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
m1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val][0],
A=bnb_optimizer.state[p2][name2][0],
blocksize=blocksize,
)
m2 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val][1],
A=bnb_optimizer.state[p2][name2][1],
blocksize=blocksize,
)
s1 = torch.stack((m1, m2))
else:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
......@@ -320,10 +393,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerr = err / (torch.abs(p1) + 1e-9)
if g.dtype == torch.bfloat16:
assert err.mean() < 0.00015
assert relerr.mean() < 0.0016
assert relerr.mean() < 0.0020 # 0.0016
else:
assert err.mean() < 0.00012
assert relerr.mean() < 0.0012
assert err.mean() < 0.00016 # 0.00012
assert relerr.mean() < 0.0016 # 0.0012
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
......@@ -345,12 +418,32 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1]
## separately and then stack them. The qmap is shared, but absmax is also stacked.
if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2":
s1 = torch.stack(
(
F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val][0],
A=bnb_optimizer.state[p2][name2][0],
blocksize=blocksize,
),
F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val][1],
A=bnb_optimizer.state[p2][name2][1],
blocksize=blocksize,
),
)
)
else:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
......@@ -362,8 +455,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
assert num_not_close.sum().item() < 20
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)
# and AdEMAMix can also be noisy, allow up to 0.05%.
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-04))
# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
......@@ -469,6 +562,7 @@ optimizer_names_benchmark = [
"paged_adam8bit_blockwise",
"paged_adamw8bit_blockwise",
"paged_lion8bit_blockwise",
"paged_ademamix8bit_blockwise",
]
......