Unverified commit 404e2776 authored by Egor, committed by GitHub

[XPU] Implemented 8bit optimizers in triton (#1692)

* implemented 8bit optimizers

* Add interface

* Commented out torch checks

* Merged

* Updated kernels

* Reused code for quant/dequant

* Removed empty line

* Changed Readme
parent 4b025748
......@@ -141,7 +141,7 @@ bitsandbytes has the following minimum requirements for all platforms:
</td>
<td></td>
<td></td>
<td>🚧</td>
<td>〰️</td>
</tr>
<tr>
<td colspan="6">🍎 <strong>macOS 14+</strong></td>
......
......@@ -4,167 +4,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
a_ptr,
c_ptr,
quant_ptr,
absmax_ptr,
num_paired_elements,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < num_paired_elements
a = tl.load(a_ptr + offsets, mask)
a = a.to(tl.uint8)
# apply conversion
scaled_int8 = tl.load(quant_ptr + a, mask)
abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK
abs_offsets = offsets // QUANT_BLOCK
mask_blocked = offsets < abs_blocks_lim
absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked)
# apply scales
out_dq = scaled_int8 * absmax
offs = block_start + tl.arange(0, SPLIT_SIZE)
mask = offs < num_paired_elements
tl.store(c_ptr + offs, out_dq, mask)
def dequant_int8_blockwise(
A_nf4: torch.Tensor,
quant_state_code: torch.Tensor,
absmax: torch.Tensor,
out: torch.Tensor,
quant_blocksize: int = 64,
):
number_of_paired_elements = A_nf4.numel()
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
dequant_8bit_kernel[grid](
A_nf4,
out,
quant_state_code,
absmax,
number_of_paired_elements,
quant_blocksize,
SPLIT_SIZE,
)
return out
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_blockwise_kernel(
A_ptr,
code_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
CODE_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
# To be able to process several blocks at once -> reshape to (SPLIT_NUM_BLOCKS, BLOCK_SIZE)
A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE))
# Calculate absmax for each block
absmax = tl.max(tl.abs(A_reshaped), axis=1)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
A_normalized = A_reshaped / absmax[:, None]
A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
for _ in range(8): # ceil(log2(code_size)) = 8 iterations for a 256-entry code; in the general case this should be a parameter
pivot = (lower_pivot + upper_pivot) // 2
val = tl.load(code_ptr + pivot)
is_higher = A_normalized > val # code[pivot]
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
# Choose closest level
lower_val = tl.load(code_ptr + lower_pivot)
upper_val = tl.load(code_ptr + upper_pivot)
lower_dist = tl.abs(A_normalized - lower_val)
upper_dist = tl.abs(A_normalized - upper_val)
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
# too slow approach
# diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
# quantized = tl.argmin(diff, axis=2).to(tl.uint8)
quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
tl.store(out_ptr + offsets, quantized_flat, mask=mask)
def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out):
n = A.numel()
split_num_blocks = 1
grid = (triton.cdiv(blocks, split_num_blocks),)
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
quantize_blockwise_kernel[grid](
A_ptr=A,
code_ptr=code,
absmax_ptr=absmax,
out_ptr=quantized_out,
n_elements=n,
BLOCK_SIZE=blocksize,
CODE_SIZE=code.numel(),
SPLIT_NUM_BLOCKS=split_num_blocks,
)
return quantized_out, absmax
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4
# @triton.autotune(
# configs=[
......@@ -587,7 +426,7 @@ def dequant_nf4_kernel(
tl.store(c_ptr + offs, out_dq, mask)
def _dequantize_4bit_impl(
def dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
......@@ -611,7 +450,7 @@ def _dequantize_4bit_impl(
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
def _dequantize_4bit_impl_passing_code(
def dequantize_4bit_impl_passing_code(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
......
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
a_ptr,
out_ptr,
code_ptr,
absmax_ptr,
n,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < n
out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK)
tl.store(out_ptr + offsets, out_dq, mask)
def dequant_8bit_blockwise(
a: torch.Tensor,
absmax: torch.Tensor,
quant_state_code: torch.Tensor,
quant_blocksize: int = 64,
dtype: torch.dtype = None,
out: torch.Tensor = None,
):
n = a.numel()
if out is None:
if dtype is None:
raise ValueError("If out is None, dtype must be specified")
out = torch.empty_like(a, dtype=dtype, device=a.device)
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
grid = (triton.cdiv(n, SPLIT_SIZE),)
dequant_8bit_kernel[grid](
a,
out,
quant_state_code,
absmax,
n,
quant_blocksize,
SPLIT_SIZE,
)
return out
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_8bit_blockwise_kernel(
A_ptr,
code_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
CODE_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
tl.store(out_ptr + offsets, quantized, mask=mask)
def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
n = A.numel()
blocks = -(n // -blocksize)
if absmax is None:
absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
if out is None:
out = torch.empty_like(A.flatten(), dtype=torch.uint8)
split_num_blocks = 1
grid = (triton.cdiv(blocks, split_num_blocks),)
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
quantize_8bit_blockwise_kernel[grid](
A_ptr=A,
code_ptr=code,
absmax_ptr=absmax,
out_ptr=out,
n_elements=n,
BLOCK_SIZE=blocksize,
CODE_SIZE=code.numel(),
SPLIT_NUM_BLOCKS=split_num_blocks,
# num_warps=1,
# num_stages=2,
)
out = out.reshape(A.shape)
return out, absmax
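# A minimal round-trip sketch of the two wrappers above, assuming an XPU (or any
# other Triton-capable) device and the 256-entry dynamic code map from
# bitsandbytes.functional.create_dynamic_map(); the per-element error is bounded
# by the block scale times half the largest gap between adjacent code entries.
def _quant_dequant_roundtrip_sketch(device: str = "xpu") -> torch.Tensor:
    from bitsandbytes.functional import create_dynamic_map

    code = create_dynamic_map().to(device)
    a = torch.randn(4096, device=device, dtype=torch.float32)
    quantized, absmax = quantize_blockwise_triton(a, code, blocksize=256)
    a_hat = dequant_8bit_blockwise(quantized, absmax, code, quant_blocksize=256, dtype=torch.float32)
    return (a - a_hat).abs().max()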
@triton.jit
def quantize_8bit_blockwise_kernel_util(
a,
code_ptr,
CODE_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_PER_TH: tl.constexpr,
):
# To be able to process several blocks at once -> reshape to (N_PER_TH, BLOCK_SIZE)
a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE))
# Calculating absmax for each block
absmax = tl.max(tl.abs(a_reshaped), axis=1)
a_normalized = a_reshaped / absmax[:, None]
a_normalized = tl.clamp(a_normalized, -1.0, 1.0)
lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32)
upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
# 8 iterations = ceil(log2(CODE_SIZE)) for a 256-entry code; in the general case
# this should be a parameter.
for _ in range(8):
pivot = (lower_pivot + upper_pivot) // 2
val = tl.load(code_ptr + pivot)
is_higher = a_normalized > val # code[pivot]
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
# Choose closest level
lower_val = tl.load(code_ptr + lower_pivot)
upper_val = tl.load(code_ptr + upper_pivot)
lower_dist = tl.abs(a_normalized - lower_val)
upper_dist = tl.abs(a_normalized - upper_val)
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
# A much slower brute-force alternative:
# diff = tl.abs(a_normalized[:, :, None] - code[None, None, :])
# quantized = tl.argmin(diff, axis=2).to(tl.uint8)
quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,))
return quantized_flat, absmax
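# For reference, the binary search above selects the nearest entry of the sorted
# code book; a plain PyTorch sketch of the same lookup (assuming `code` is sorted
# ascending, as the dynamic 8-bit code map is) could look like this:
def _nearest_code_sketch(a_normalized: torch.Tensor, code: torch.Tensor) -> torch.Tensor:
    upper = torch.bucketize(a_normalized, code).clamp(max=code.numel() - 1)
    lower = (upper - 1).clamp(min=0)
    # Ties go to the lower index, matching the kernel above.
    pick_lower = (a_normalized - code[lower]).abs() <= (code[upper] - a_normalized).abs()
    return torch.where(pick_lower, lower, upper).to(torch.uint8)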
@triton.jit
def dequant_8bit_blockwise_kernel_util(
a_ptr,
offsets,
code_ptr,
absmax_ptr,
mask,
BLOCK_SIZE: tl.constexpr,
):
a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8)
scaled_int8 = tl.load(code_ptr + a, mask)
# Load scales
absmax_offsets = offsets // BLOCK_SIZE
absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last")
# Apply scales
out_dq = scaled_int8 * absmax
return out_dq
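# Summary of the block-wise 8-bit scheme implemented by the kernels above, for a
# block b of BLOCK_SIZE contiguous elements:
#     absmax[b]   = max_i |A[b, i]|
#     q[b, i]     = argmin_j |A[b, i] / absmax[b] - code[j]|   (stored as uint8)
#     A_hat[b, i] = code[q[b, i]] * absmax[b]                  (dequantization)
# so the per-element reconstruction error is at most absmax[b] times half the
# largest gap between adjacent code entries.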
import math
from typing import Optional
import torch
......@@ -6,6 +7,12 @@ import triton
import triton.language as tl
# from triton.language.extra import libdevice
from .kernels_8bit_quant import (
dequant_8bit_blockwise,
dequant_8bit_blockwise_kernel_util,
quantize_8bit_blockwise_kernel_util,
quantize_blockwise_triton,
)
MOMENTUM = 0
RMSPROP = 1
......@@ -470,3 +477,678 @@ def optimizer_update_32bit_impl(
N_PER_TH,
num_warps=2,
)
###########################################
# Pure torch implementation for reference #
###########################################
@torch.compile
def _dequantize_blockwise_pytorch(
A: torch.Tensor,
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Pure PyTorch reference implementation for block-wise dequantization.
"""
if A.numel() == 0:
return torch.empty_like(A, dtype=dtype)
A_flat = A.flatten()
num_elements = A_flat.numel()
dequantized_flat = code.to(A.device)[A_flat.long()].to(dtype)
num_blocks = math.ceil(num_elements / blocksize)
pad_len = num_blocks * blocksize - num_elements
if pad_len > 0:
dequantized_flat = torch.nn.functional.pad(dequantized_flat, (0, pad_len))
dequantized_blocks = dequantized_flat.reshape(num_blocks, blocksize)
rescaled_blocks = dequantized_blocks * absmax.unsqueeze(1).to(dtype)
rescaled_flat = rescaled_blocks.flatten()
if pad_len > 0:
rescaled_flat = rescaled_flat[:-pad_len]
return rescaled_flat.reshape(A.shape)
@torch.compile
def _quantize_blockwise_pytorch(
A: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pure PyTorch reference implementation for block-wise quantization.
"""
if A.numel() == 0:
return torch.empty_like(A, dtype=torch.uint8), torch.empty(0, dtype=torch.float32, device=A.device)
A_flat = A.flatten()
num_elements = A_flat.numel()
num_blocks = math.ceil(num_elements / blocksize)
pad_len = num_blocks * blocksize - num_elements
if pad_len > 0:
A_flat = torch.nn.functional.pad(A_flat, (0, pad_len))
A_blocks = A_flat.reshape(num_blocks, blocksize)
absmax = torch.max(torch.abs(A_blocks), dim=1, keepdim=True)[0]
absmax[absmax == 0] = 1.0
scaled_blocks = A_blocks / absmax
# Inefficient but straightforward nearest-code search: materializes a (blocks, blocksize, code_size) tensor, so it uses a lot of memory.
diff = torch.abs(scaled_blocks.unsqueeze(2) - code.to(A.device))
quantized_indices = torch.argmin(diff, dim=2).to(torch.uint8)
quantized_flat = quantized_indices.flatten()
if pad_len > 0:
quantized_flat = quantized_flat[:-pad_len]
return quantized_flat.reshape(A.shape), absmax.flatten()
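# A quick consistency check between the reference above and the Triton path, as a
# sketch; the "xpu" device and bitsandbytes.functional.create_dynamic_map() are
# assumptions, and codes may differ on exact ties, so dequantized values are compared.
def _check_reference_matches_triton(device: str = "xpu", blocksize: int = 256) -> None:
    from bitsandbytes.functional import create_dynamic_map

    code = create_dynamic_map().to(device)
    a = torch.randn(10_000, device=device, dtype=torch.float32)
    q_ref, absmax_ref = _quantize_blockwise_pytorch(a, code, blocksize)
    q_tri, absmax_tri = quantize_blockwise_triton(a, code, blocksize)
    a_ref = _dequantize_blockwise_pytorch(q_ref, absmax_ref, code, blocksize, torch.float32)
    a_tri = dequant_8bit_blockwise(q_tri, absmax_tri, code, blocksize, dtype=torch.float32)
    torch.testing.assert_close(a_ref, a_tri)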
# Main reference entry point: full 8-bit block-wise optimizer update in pure PyTorch
def optimizer_update_8bit_blockwise_pytorch(
p: torch.Tensor,
g: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float, # used by AdEMAMix
alpha: float, # used by AdEMAMix
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float,
gnorm_scale: float,
skip_zeros: bool,
# ADEMIX
*,
optimizer_name: str,
) -> None:
"""
Pure PyTorch implementation of the 8-bit block-wise optimizer update step.
This version ensures high-precision updates for float16 parameters.
"""
if skip_zeros:
raise ValueError("skip_zeros is not supported on XPU yet.")
blocksize = 256
with torch.no_grad():
# Dequantize states to perform updates in 32-bit precision
if optimizer_name == "ademamix" and absmax1.ndim == 2:
# For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked.
s1_1_fp32 = _dequantize_blockwise_pytorch(state1[0], absmax1[0], qmap1, blocksize, torch.float32)
s1_2_fp32 = _dequantize_blockwise_pytorch(state1[1], absmax1[1], qmap1, blocksize, torch.float32)
state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32])
else:
state1_fp32 = _dequantize_blockwise_pytorch(state1, absmax1, qmap1, blocksize, torch.float32)
state2_fp32 = None
if state2 is not None:
state2_fp32 = _dequantize_blockwise_pytorch(state2, absmax2, qmap2, blocksize, torch.float32)
grad = g.float() * gnorm_scale
# Create a 32-bit copy of the parameter for high-precision updates
p_fp32 = p.data.float()
if optimizer_name == "adam":
state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
bias_correction1 = 1.0 - beta1**step
bias_correction2 = 1.0 - beta2**step
denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps)
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1)
elif optimizer_name == "ademamix":
m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1]
nu_fp32 = state2_fp32
m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3)
nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
bias_correction1 = 1.0 - beta1**step
bias_correction2 = math.sqrt(1.0 - beta2**step)
update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps)
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
p_fp32.add_(update, alpha=-lr)
state1_fp32 = torch.stack([m1_fp32, m2_fp32])
elif optimizer_name == "momentum":
grad.add_(p_fp32, alpha=weight_decay)
if step == 1:
state1_fp32.copy_(grad)
else:
state1_fp32.mul_(beta1).add_(grad)
p_fp32.add_(state1_fp32, alpha=-lr)
elif optimizer_name == "rmsprop":
grad.add_(p_fp32, alpha=weight_decay)
state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1)
p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
elif optimizer_name == "lion":
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1))
p_fp32.add_(update_dir, alpha=-lr)
state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2)
elif optimizer_name == "adagrad":
grad.add_(p_fp32, alpha=weight_decay)
state1_fp32.addcmul_(grad, grad, value=1.0)
p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
else:
raise NotImplementedError(
f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available."
)
# Copy the updated 32-bit parameter back to the original tensor
p.data.copy_(p_fp32)
# Re-quantize states and update state tensors in-place
if optimizer_name == "ademamix":
new_m1_8bit, new_absmax_m1 = _quantize_blockwise_pytorch(state1_fp32[0], qmap1, blocksize)
new_m2_8bit, new_absmax_m2 = _quantize_blockwise_pytorch(state1_fp32[1], qmap1, blocksize)
state1[0].copy_(new_m1_8bit)
state1[1].copy_(new_m2_8bit)
absmax1[0].copy_(new_absmax_m1)
absmax1[1].copy_(new_absmax_m2)
new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize)
state2.copy_(new_state2_8bit)
absmax2.copy_(new_absmax2)
else:
new_state1_8bit, new_absmax1 = _quantize_blockwise_pytorch(state1_fp32, qmap1, blocksize)
state1.copy_(new_state1_8bit)
absmax1.copy_(new_absmax1)
if state2_fp32 is not None:
new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize)
state2.copy_(new_state2_8bit)
absmax2.copy_(new_absmax2)
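# A single 8-bit Adam step through the reference path above, as a sketch. The
# zero-initialized uint8 states with zero absmax (which dequantize to zero), the
# "xpu" device, and the use of one signed code map for both states are assumptions;
# the real optimizer uses separate signed/unsigned maps for the two Adam states.
def _adam8bit_reference_step_sketch(device: str = "xpu") -> None:
    from bitsandbytes.functional import create_dynamic_map

    blocksize = 256
    p = torch.randn(4096, device=device, dtype=torch.float32)
    g = torch.randn_like(p)
    blocks = -(p.numel() // -blocksize)
    qmap = create_dynamic_map().to(device)
    state1 = torch.zeros(p.numel(), dtype=torch.uint8, device=device)
    state2 = torch.zeros(p.numel(), dtype=torch.uint8, device=device)
    absmax1 = torch.zeros(blocks, dtype=torch.float32, device=device)
    absmax2 = torch.zeros(blocks, dtype=torch.float32, device=device)
    optimizer_update_8bit_blockwise_pytorch(
        p, g, state1, state2,
        beta1=0.9, beta2=0.999, beta3=0.0, alpha=0.0,
        eps=1e-8, step=1, lr=1e-3,
        qmap1=qmap, qmap2=qmap, absmax1=absmax1, absmax2=absmax2,
        weight_decay=0.0, gnorm_scale=1.0, skip_zeros=False,
        optimizer_name="adam",
    )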
#######################################
# Mixed torch + triton implementation #
#######################################
# Much more memory efficient than the pure PyTorch reference, since quantization/dequantization run in Triton
def optimizer_update_8bit_blockwise_triton_quant(
p: torch.Tensor,
g: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float, # used by AdEMAMix
alpha: float, # used by AdEMAMix
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float,
gnorm_scale: float,
skip_zeros: bool,
# ADEMIX
*,
optimizer_name: str,
) -> None:
"""
Pure PyTorch implementation of the 8-bit block-wise optimizer update step.
This version ensures high-precision updates for float16 parameters.
"""
if skip_zeros and not torch.any(g):
return
blocksize = 256
grad = g.float() * gnorm_scale
with torch.no_grad():
# Create a 32-bit copy of the parameter for high-precision updates
p_fp32 = p.data.float()
# Dequantize states to perform updates in 32-bit precision
if optimizer_name == "ademamix" and absmax1.ndim == 2:
# For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked.
s1_1_fp32 = dequant_8bit_blockwise(state1[0], absmax1[0], qmap1, blocksize, dtype=torch.float32)
s1_2_fp32 = dequant_8bit_blockwise(state1[1], absmax1[1], qmap1, blocksize, dtype=torch.float32)
state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32])
else:
state1_fp32 = dequant_8bit_blockwise(state1, absmax1, qmap1, blocksize, dtype=torch.float32)
state2_fp32 = None
if state2 is not None:
state2_fp32 = dequant_8bit_blockwise(state2, absmax2, qmap2, blocksize, dtype=torch.float32)
# Apply optimizer-specific update logic
if optimizer_name == "adam":
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
bias_correction1 = 1.0 - beta1**step
bias_correction2 = 1.0 - beta2**step
denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps)
p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1)
elif optimizer_name == "ademamix":
m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1]
nu_fp32 = state2_fp32
m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3)
nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
bias_correction1 = 1.0 - beta1**step
bias_correction2 = math.sqrt(1.0 - beta2**step)
update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps)
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
p_fp32.add_(update, alpha=-lr)
state1_fp32 = torch.stack([m1_fp32, m2_fp32])
elif optimizer_name == "momentum":
grad.add_(p_fp32, alpha=weight_decay)
if step == 1:
state1_fp32.copy_(grad)
else:
state1_fp32.mul_(beta1).add_(grad)
p_fp32.add_(state1_fp32, alpha=-lr)
elif optimizer_name == "rmsprop":
grad.add_(p_fp32, alpha=weight_decay)
state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1)
p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
elif optimizer_name == "lion":
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1))
p_fp32.add_(update_dir, alpha=-lr)
state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2)
elif optimizer_name == "adagrad":
grad.add_(p_fp32, alpha=weight_decay)
state1_fp32.addcmul_(grad, grad, value=1.0)
p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
else:
raise NotImplementedError(
f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available."
)
# Copy the updated 32-bit parameter back to the original tensor
p.data.copy_(p_fp32)
# Re-quantize states and update state tensors in-place
if optimizer_name == "ademamix":
new_m1_8bit, new_absmax_m1 = quantize_blockwise_triton(state1_fp32[0], qmap1, blocksize)
new_m2_8bit, new_absmax_m2 = quantize_blockwise_triton(state1_fp32[1], qmap1, blocksize)
state1[0].copy_(new_m1_8bit)
state1[1].copy_(new_m2_8bit)
absmax1[0].copy_(new_absmax_m1)
absmax1[1].copy_(new_absmax_m2)
new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize)
state2.copy_(new_state2_8bit)
absmax2.copy_(new_absmax2)
else:
new_state1_8bit, new_absmax1 = quantize_blockwise_triton(state1_fp32, qmap1, blocksize)
state1.copy_(new_state1_8bit)
absmax1.copy_(new_absmax1)
if state2_fp32 is not None:
new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize)
state2.copy_(new_state2_8bit)
absmax2.copy_(new_absmax2)
#########################
# Triton implementation #
#########################
@triton.jit
def _optimizer_update_1state_8bit_blockwise_triton_kernel(
# Tensors
p_ptr,
g_ptr,
state1_ptr,
state2_ptr,
beta1: tl.constexpr,
beta2: tl.constexpr,
beta3,
alpha,
eps: tl.constexpr,
step,
beta1_step,
beta2_step,
lr,
qmap1_ptr,
qmap2_ptr,
absmax1_ptr,
absmax2_ptr,
weight_decay,
gnorm_scale,
# Meta-parameters
n_elements,
BLOCK_SIZE_N: tl.constexpr,
N_PER_TH: tl.constexpr,
OPTIMIZER_ID: tl.constexpr,
):
"""
Triton kernel for 8-bit optimizers that use one momentum state.
Supports: Momentum, RMSprop, Adagrad, Lion.
"""
# 1. Boilerplate: pid, offsets, mask
pid = tl.program_id(axis=0)
block_start_idx = pid * N_PER_TH
offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH)
mask = offsets < n_elements
# 2. Load and dequantize tensors
g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale
p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
# 3. Optimizer-specific updates
# Lion uses decoupled weight decay applied directly to the parameter
if weight_decay > 0.0 and OPTIMIZER_ID == 4: # LION
p *= 1.0 - lr * weight_decay
# Momentum, RMSprop and Adagrad add the decay term to the gradient
elif weight_decay > 0.0:
g += p * weight_decay
# Momentum update
if OPTIMIZER_ID == 0: # MOMENTUM
if step == 1:
s1 = g
else:
s1 = s1 * beta1 + g
p -= lr * s1
# RMSprop update
elif OPTIMIZER_ID == 1: # RMSPROP
s1 = s1 * beta1 + (1.0 - beta1) * g * g
p -= lr * (g / (tl.sqrt(s1) + eps))
# Adagrad update
elif OPTIMIZER_ID == 2: # ADAGRAD
s1 += g * g
p -= lr * (g / (tl.sqrt(s1) + eps))
# Lion update
elif OPTIMIZER_ID == 4: # LION
val = s1 * beta1 + (1.0 - beta1) * g
update = tl.where(val > 0.0, 1.0, tl.where(val < 0.0, -1.0, 0.0))
p -= lr * update
s1 = s1 * beta2 + (1.0 - beta2) * g
# 4. Store updated parameter and requantized state
tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state1_ptr + offsets, s1_codes, mask=mask)
tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1)
@triton.jit
def _optimizer_update_2state_8bit_blockwise_triton_kernel(
# Tensors
p_ptr,
g_ptr,
state1_ptr,
state2_ptr,
beta1: tl.constexpr,
beta2: tl.constexpr,
beta3, # used by AdEMAMix
alpha, # used by AdEMAMix
eps: tl.constexpr,
step,
beta1_step,
beta2_step,
lr,
qmap1_ptr,
qmap2_ptr,
absmax1_ptr,
absmax2_ptr,
weight_decay: tl.constexpr,
gnorm_scale: tl.constexpr,
# Meta-parameters
n_elements,
BLOCK_SIZE_N: tl.constexpr,
N_PER_TH: tl.constexpr,
OPTIMIZER_ID: tl.constexpr,
):
"""
Triton kernel for 8-bit optimizers that use two momentum states.
Supports: Adam, AdEMAMix.
"""
# 1. Boilerplate: pid, offsets, mask
pid = tl.program_id(axis=0)
block_start_idx = pid * N_PER_TH
offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH)
mask = offsets < n_elements
# 2. Load and dequantize tensors
g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale
p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
# 3. Optimizer-specific updates
if OPTIMIZER_ID == 3: # ADAM
s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
s2 = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
s1 = s1 * beta1 + (1.0 - beta1) * g
s2 = s2 * beta2 + (1.0 - beta2) * g * g
# In torch==2.7 on XPU, libdevice.pow raises an error, so for backwards
# compatibility the bias-correction factors are precomputed on the host.
# bias_correction1 = 1.0 - libdevice.pow(beta1, step)
# bias_correction2 = 1.0 - libdevice.pow(beta2, step)
bias_correction1 = 1.0 - beta1_step
bias_correction2 = 1.0 - beta2_step
if weight_decay > 0.0:
p *= 1.0 - lr * weight_decay
denom = tl.sqrt(s2) / tl.sqrt(bias_correction2) + eps
p -= (lr / bias_correction1) * (s1 / denom)
# Store updated parameter
tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
# Requantize and store states
s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state1_ptr + offsets, s1_codes, mask=mask)
tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1)
s2_codes, new_absmax2 = quantize_8bit_blockwise_kernel_util(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state2_ptr + offsets, s2_codes, mask=mask)
tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax2)
elif OPTIMIZER_ID == 5: # ADEMAMIX
# AdEMAMix has a stacked state1 (m1, m2) and state2 (nu)
m1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
m2 = dequant_8bit_blockwise_kernel_util(
state1_ptr + n_elements,
offsets,
qmap1_ptr,
absmax1_ptr + n_elements // BLOCK_SIZE_N,
mask,
BLOCK_SIZE_N,
)
nu = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
m1 = m1 * beta1 + (1.0 - beta1) * g
m2 = m2 * beta3 + (1.0 - beta3) * g
nu = nu * beta2 + (1.0 - beta2) * g * g
# In torch==2.7 on XPU, libdevice.pow raises an error, so for backwards
# compatibility the bias-correction factors are precomputed on the host.
# bias_correction1 = 1.0 - libdevice.pow(beta1, step)
# bias_correction2 = tl.sqrt(1.0 - libdevice.pow(beta2, step))
bias_correction1 = 1.0 - beta1_step
bias_correction2 = tl.sqrt(1.0 - beta2_step)
update = (m1 / bias_correction1 + alpha * m2) / (tl.sqrt(nu) / bias_correction2 + eps)
if weight_decay > 0.0:
p *= 1.0 - lr * weight_decay
p -= lr * update
# Store updated parameter
tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
# Requantize and store all three states
m1_codes, new_absmax_m1 = quantize_8bit_blockwise_kernel_util(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state1_ptr + offsets, m1_codes, mask=mask)
tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_m1)
m2_codes, new_absmax_m2 = quantize_8bit_blockwise_kernel_util(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state1_ptr + n_elements + offsets, m2_codes, mask=mask)
tl.store(
absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N,
new_absmax_m2,
)
nu_codes, new_absmax_nu = quantize_8bit_blockwise_kernel_util(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state2_ptr + offsets, nu_codes, mask=mask)
tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_nu)
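# With the bias corrections precomputed on the host as beta1_step = beta1**step and
# beta2_step = beta2**step, the two branches above implement:
#   Adam:      m  = b1*m  + (1-b1)*g,   v  = b2*v  + (1-b2)*g^2
#              p -= lr/(1 - b1^t) * m / (sqrt(v)/sqrt(1 - b2^t) + eps)
#   AdEMAMix:  m1 = b1*m1 + (1-b1)*g,   m2 = b3*m2 + (1-b3)*g,   nu = b2*nu + (1-b2)*g^2
#              p -= lr * (m1/(1 - b1^t) + alpha*m2) / (sqrt(nu)/sqrt(1 - b2^t) + eps)
# with optional decoupled weight decay p *= 1 - lr*weight_decay applied beforehand.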
name2optimizer_fn = {
"momentum": _optimizer_update_1state_8bit_blockwise_triton_kernel,
"rmsprop": _optimizer_update_1state_8bit_blockwise_triton_kernel,
"adagrad": _optimizer_update_1state_8bit_blockwise_triton_kernel,
"adam": _optimizer_update_2state_8bit_blockwise_triton_kernel,
"lion": _optimizer_update_1state_8bit_blockwise_triton_kernel,
"ademamix": _optimizer_update_2state_8bit_blockwise_triton_kernel,
}
def optimizer_update_8bit_blockwise_impl(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
if skip_zeros:
raise NotImplementedError("skip_zeros is not supported on XPU yet")
if optimizer_name == "ademamix":
# Handle AdEMAMix's stacked state tensors
if state1.dim() < 2 or state1.shape[0] != 2:
raise ValueError(
f"For ademamix, state1 must be a stacked tensor of shape (2, ...), but got {state1.shape}"
)
if absmax1.dim() < 2 or absmax1.shape[0] != 2:
raise ValueError(
f"For ademamix, absmax1 must be a stacked tensor of shape (2, ...), but got {absmax1.shape}"
)
BLOCK_SIZE = 256
N_PER_TH = 1 # Number of blocks processed per thread.
grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),)
fn = name2optimizer_fn[optimizer_name]
optimizer_id = name2optimizer_id[optimizer_name]
# In torch==2.7 on XPU, libdevice.pow raises an error, so for backwards
# compatibility we precompute the bias-correction factors here on the host.
beta1_step = beta1**step
beta2_step = beta2**step
fn[grid](
p,
g,
state1,
state2,
beta1,
beta2,
beta3,
alpha,
eps,
step,
beta1_step,
beta2_step,
lr,
qmap1,
qmap2,
absmax1,
absmax2,
weight_decay,
gnorm_scale,
p.numel(),
BLOCK_SIZE_N=BLOCK_SIZE,
N_PER_TH=N_PER_TH,
OPTIMIZER_ID=optimizer_id,
num_warps=2,
)
# Alternative implementations, kept for benchmarking:
# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_pytorch
# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_pytorch)
# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_quant
# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_triton_quant)
optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_impl
......@@ -3,7 +3,7 @@ from typing import Optional
import torch
from . import kernels_optim, triton_kernels
from . import kernels_4bit, kernels_8bit_quant, kernels_optim
# codes currently unused, kept for reference
# Should be the same for quant/dequant
......@@ -17,19 +17,9 @@ torch_accelerator_module = getattr(torch, device_type, torch.cuda)
def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
# torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
out = torch.empty_like(A.flatten(), dtype=torch.uint8)
with torch_accelerator_module.device(A.device):
triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
out = out.reshape(A.shape)
return out, absmax.float()
out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize)
return out, absmax.float()
def dequantize_blockwise(
......@@ -38,22 +28,24 @@ def dequantize_blockwise(
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
# torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
out = torch.empty_like(A, dtype=dtype, device=A.device)
with torch_accelerator_module.device(A.device):
triton_kernels.dequant_int8_blockwise(
out = kernels_8bit_quant.dequant_8bit_blockwise(
A,
code,
absmax,
out,
code,
blocksize,
dtype=dtype,
)
return out
def dequantize_blockwise_inplace(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
A: torch.Tensor,
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
......@@ -62,12 +54,13 @@ def dequantize_blockwise_inplace(
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
with torch_accelerator_module.device(A.device):
triton_kernels.dequant_int8_blockwise(
kernels_8bit_quant.dequant_8bit_blockwise(
A,
code,
absmax,
out,
code,
blocksize,
dtype=dtype,
out=out,
)
......@@ -92,7 +85,7 @@ def quantize_4bit(
out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
with torch_accelerator_module.device(A.device):
triton_kernels.quantize_4bit_blockwise_triton(
kernels_4bit.quantize_4bit_blockwise_triton(
A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
)
packed = out
......@@ -126,9 +119,8 @@ def dequantize_4bit(
A = A.squeeze().view(torch.uint8).unsqueeze(1)
out = torch.empty(shape, dtype=dtype, device=A.device)
with torch_accelerator_module.device(A.device):
triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out
......@@ -145,7 +137,7 @@ def dequantize_4bit_inplace(
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
with torch_accelerator_module.device(A.device):
triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
def gemv_4bit(
......@@ -162,7 +154,7 @@ def gemv_4bit(
B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
with torch_accelerator_module.device(A.device):
triton_kernels._dequantize_4bit_impl_passing_code(
kernels_4bit.dequantize_4bit_impl_passing_code(
B,
absmax,
blocksize,
......@@ -171,11 +163,95 @@ def gemv_4bit(
out=B_dq_triton,
)
return torch.nn.functional.linear(
A,
B_dq_triton,
bias=None,
)
# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_pytorch
# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_pytorch) # 60ms
# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_triton_quant #2.8ms
# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_triton_quant) # 2.3ms
optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_impl # ~0.95ms for adam
def optimizer_update_8bit_blockwise(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
# torch._check(
# g.numel() == p.numel(),
# lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
# )
# compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# torch._check(
# g.dtype in compute_dtypes,
# lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
# )
# torch._check(
# g.dtype == p.dtype,
# lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
# )
# torch._check(
# state1.dtype == torch.uint8,
# lambda: f"state1 must be uint8, got {state1.dtype}",
# )
# torch._check(
# qmap1.dtype == absmax1.dtype == torch.float32,
# lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
# )
# if state2 is not None:
# torch._check(
# state2.dtype == torch.uint8,
# lambda: f"state2 must be uint8, got {state2.dtype}",
# )
# torch._check(
# qmap2.dtype == absmax2.dtype == torch.float32,
# lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
# )
with torch_accelerator_module.device(state1.device):
optimizer_update_8bit_blockwise_impl(
optimizer_name=optimizer_name,
g=g,
p=p,
state1=state1,
state2=state2,
beta1=beta1,
beta2=beta2,
beta3=beta3,
alpha=alpha,
eps=eps,
step=step,
lr=lr,
qmap1=qmap1,
qmap2=qmap2,
absmax1=absmax1,
absmax2=absmax2,
weight_decay=weight_decay,
gnorm_scale=gnorm_scale,
skip_zeros=skip_zeros,
)
def optimizer_update_32bit(
......
......@@ -156,6 +156,10 @@ if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(
triton_ops.optimizer_update_8bit_blockwise
)
register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)
@register_kernel("bitsandbytes::dequantize_4bit", "xpu")
def _(
......@@ -232,6 +236,7 @@ elif triton_available:
register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(triton_ops.optimizer_update_8bit_blockwise)
register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)
else:
logger.warning("Registering PyTorch bitsandbytes kernels for XPU because no native library or Triton package was found.")
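# Once these kernels are registered, the existing 8-bit optimizers should route
# through them transparently on XPU; a minimal sketch (assuming bnb.optim.Adam8bit
# is available unchanged in the XPU build):
#
#   import torch, bitsandbytes as bnb
#   model = torch.nn.Linear(1024, 1024).to("xpu")
#   opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
#   model(torch.randn(16, 1024, device="xpu")).sum().backward()
#   opt.step()  # dispatches to bitsandbytes::optimizer_update_8bit_blockwise on XPU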