Unverified commit 404e2776 authored by Egor, committed by GitHub

[XPU] Implemented 8bit optimizers in triton (#1692)

* implemented 8bit optimizers

* Add interface

* Commented out torch checks

* Merged

* Updated kernels

* Reused code for quant/dequant

* Removed empty line

* Changed Readme
parent 4b025748
@@ -141,7 +141,7 @@ bitsandbytes has the following minimum requirements for all platforms:
</td>
<td></td>
<td></td>
- <td>🚧</td>
+ <td>〰️</td>
</tr>
<tr>
<td colspan="6">🍎 <strong>macOS 14+</strong></td>
...
@@ -4,167 +4,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
a_ptr,
c_ptr,
quant_ptr,
absmax_ptr,
num_paired_elements,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < num_paired_elements
a = tl.load(a_ptr + offsets, mask)
a = a.to(tl.uint8)
# apply conversion
scaled_int8 = tl.load(quant_ptr + a, mask)
abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK
abs_offsets = offsets // QUANT_BLOCK
mask_blocked = offsets < abs_blocks_lim
absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked)
# apply scales
out_dq = scaled_int8 * absmax
offs = block_start + tl.arange(0, SPLIT_SIZE)
mask = offs < num_paired_elements
tl.store(c_ptr + offs, out_dq, mask)
def dequant_int8_blockwise(
A_nf4: torch.Tensor,
quant_state_code: torch.Tensor,
absmax: torch.Tensor,
out: torch.Tensor,
quant_blocksize: int = 64,
):
number_of_paired_elements = A_nf4.numel()
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
dequant_8bit_kernel[grid](
A_nf4,
out,
quant_state_code,
absmax,
number_of_paired_elements,
quant_blocksize,
SPLIT_SIZE,
)
return out
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_blockwise_kernel(
A_ptr,
code_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
CODE_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
# To be able to process several blocks at once, reshape to (SPLIT_NUM_BLOCKS, BLOCK_SIZE)
A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE))
# Calculate absmax for each block
absmax = tl.max(tl.abs(A_reshaped), axis=1)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
A_normalized = A_reshaped / absmax[:, None]
A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
for _ in range(8):  # ceil(log2(CODE_SIZE)) = 8 for a 256-entry code; in the general case this should be a parameter
pivot = (lower_pivot + upper_pivot) // 2
val = tl.load(code_ptr + pivot)
is_higher = A_normalized > val # code[pivot]
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
# Choose closest level
lower_val = tl.load(code_ptr + lower_pivot)
upper_val = tl.load(code_ptr + upper_pivot)
lower_dist = tl.abs(A_normalized - lower_val)
upper_dist = tl.abs(A_normalized - upper_val)
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
# too slow approach
# diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
# quantized = tl.argmin(diff, axis=2).to(tl.uint8)
quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
tl.store(out_ptr + offsets, quantized_flat, mask=mask)
def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out):
n = A.numel()
split_num_blocks = 1
grid = (triton.cdiv(blocks, split_num_blocks),)
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
quantize_blockwise_kernel[grid](
A_ptr=A,
code_ptr=code,
absmax_ptr=absmax,
out_ptr=quantized_out,
n_elements=n,
BLOCK_SIZE=blocksize,
CODE_SIZE=code.numel(),
SPLIT_NUM_BLOCKS=split_num_blocks,
)
return quantized_out, absmax
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4
# @triton.autotune(
# configs=[
@@ -587,7 +426,7 @@ def dequant_nf4_kernel(
tl.store(c_ptr + offs, out_dq, mask)
- def _dequantize_4bit_impl(
+ def dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
@@ -611,7 +450,7 @@ def _dequantize_4bit_impl(
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
- def _dequantize_4bit_impl_passing_code(
+ def dequantize_4bit_impl_passing_code(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
...
from typing import Optional
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
a_ptr,
out_ptr,
code_ptr,
absmax_ptr,
n,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < n
out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK)
tl.store(out_ptr + offsets, out_dq, mask)
def dequant_8bit_blockwise(
a: torch.Tensor,
absmax: torch.Tensor,
quant_state_code: torch.Tensor,
quant_blocksize: int = 64,
dtype: Optional[torch.dtype] = None,
out: Optional[torch.Tensor] = None,
):
n = a.numel()
if out is None:
if dtype is None:
raise ValueError("If out is None, dtype must be specified")
out = torch.empty_like(a, dtype=dtype, device=a.device)
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
grid = (triton.cdiv(n, SPLIT_SIZE),)
dequant_8bit_kernel[grid](
a,
out,
quant_state_code,
absmax,
n,
quant_blocksize,
SPLIT_SIZE,
)
return out
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_8bit_blockwise_kernel(
A_ptr,
code_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
CODE_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
tl.store(out_ptr + offsets, quantized, mask=mask)
def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
n = A.numel()
blocks = -(n // -blocksize)
if absmax is None:
absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
if out is None:
out = torch.empty_like(A.flatten(), dtype=torch.uint8)
split_num_blocks = 1
grid = (triton.cdiv(blocks, split_num_blocks),)
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
quantize_8bit_blockwise_kernel[grid](
A_ptr=A,
code_ptr=code,
absmax_ptr=absmax,
out_ptr=out,
n_elements=n,
BLOCK_SIZE=blocksize,
CODE_SIZE=code.numel(),
SPLIT_NUM_BLOCKS=split_num_blocks,
# num_warps=1,
# num_stages=2,
)
out = out.reshape(A.shape)
return out, absmax
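# Illustrative round-trip sketch (not part of the library): pushes a tensor through the
# Triton block-wise quantizer and dequantizer defined above. The uniform code table built
# with torch.linspace is only a stand-in for the dynamic quantization map that
# bitsandbytes normally supplies, and a Triton-capable device (e.g. "xpu") is assumed.
def _roundtrip_example(device: str = "xpu"):
    code = torch.linspace(-1.0, 1.0, 256, device=device, dtype=torch.float32)
    A = torch.randn(4096, device=device, dtype=torch.float32)
    q, absmax = quantize_blockwise_triton(A, code, blocksize=256)
    A_dq = dequant_8bit_blockwise(q, absmax, code, quant_blocksize=256, dtype=torch.float32)
    # With 256 levels, the per-element error is bounded by absmax times half the code spacing.
    return (A - A_dq).abs().max()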
@triton.jit
def quantize_8bit_blockwise_kernel_util(
a,
code_ptr,
CODE_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_PER_TH: tl.constexpr,
):
# To be able to process several blocks at once, reshape to (N_PER_TH, BLOCK_SIZE)
a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE))
# Calculating absmax for each block
absmax = tl.max(tl.abs(a_reshaped), axis=1)
a_normalized = a_reshaped / absmax[:, None]
a_normalized = tl.clamp(a_normalized, -1.0, 1.0)
lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32)
upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
# ceil(log2(CODE_SIZE)) = 8 for a 256-entry code; in the general case this should be a parameter
for _ in range(8):
pivot = (lower_pivot + upper_pivot) // 2
val = tl.load(code_ptr + pivot)
is_higher = a_normalized > val # code[pivot]
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
# Choose closest level
lower_val = tl.load(code_ptr + lower_pivot)
upper_val = tl.load(code_ptr + upper_pivot)
lower_dist = tl.abs(a_normalized - lower_val)
upper_dist = tl.abs(a_normalized - upper_val)
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
# too slow approach
# diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
# quantized = tl.argmin(diff, axis=2).to(tl.uint8)
quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,))
return quantized_flat, absmax
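# Plain-PyTorch sketch (illustration only, not used by the kernels) of the nearest-code
# lookup that the 8-iteration binary search above performs: torch.searchsorted plays the
# role of the pivot loop, and the two neighbouring code entries are compared to pick the
# closer one. Assumes `code` is sorted ascending, as the quantization maps are.
def _nearest_code_index_reference(x_normalized: torch.Tensor, code: torch.Tensor) -> torch.Tensor:
    upper = torch.searchsorted(code, x_normalized).clamp(max=code.numel() - 1)
    lower = (upper - 1).clamp(min=0)
    pick_lower = (x_normalized - code[lower]).abs() <= (code[upper] - x_normalized).abs()
    return torch.where(pick_lower, lower, upper).to(torch.uint8)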
@triton.jit
def dequant_8bit_blockwise_kernel_util(
a_ptr,
offsets,
code_ptr,
absmax_ptr,
mask,
BLOCK_SIZE: tl.constexpr,
):
a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8)
scaled_int8 = tl.load(code_ptr + a, mask)
# Load scales
absmax_offsets = offsets // BLOCK_SIZE
absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last")
# Apply scales
out_dq = scaled_int8 * absmax
return out_dq
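# Plain-PyTorch sketch (illustration only) of what dequant_8bit_blockwise_kernel_util
# computes per element: look up the code value for each uint8 index (here a flattened
# tensor) and scale it by the absmax of the block the element belongs to.
def _dequant_blockwise_reference(a_uint8: torch.Tensor, code: torch.Tensor, absmax: torch.Tensor, block_size: int):
    idx = torch.arange(a_uint8.numel(), device=a_uint8.device)
    return code[a_uint8.long()] * absmax[idx // block_size]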
+ import math
from typing import Optional
import torch
@@ -6,6 +7,12 @@ import triton
import triton.language as tl
# from triton.language.extra import libdevice
+ from .kernels_8bit_quant import (
+     dequant_8bit_blockwise,
+     dequant_8bit_blockwise_kernel_util,
+     quantize_8bit_blockwise_kernel_util,
+     quantize_blockwise_triton,
+ )
MOMENTUM = 0
RMSPROP = 1
@@ -470,3 +477,678 @@ def optimizer_update_32bit_impl(
N_PER_TH,
num_warps=2,
)
###########################################
# Pure torch implementation for reference #
###########################################
@torch.compile
def _dequantize_blockwise_pytorch(
A: torch.Tensor,
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Pure PyTorch reference implementation for block-wise dequantization.
"""
if A.numel() == 0:
return torch.empty_like(A, dtype=dtype)
A_flat = A.flatten()
num_elements = A_flat.numel()
dequantized_flat = code.to(A.device)[A_flat.long()].to(dtype)
num_blocks = math.ceil(num_elements / blocksize)
pad_len = num_blocks * blocksize - num_elements
if pad_len > 0:
dequantized_flat = torch.nn.functional.pad(dequantized_flat, (0, pad_len))
dequantized_blocks = dequantized_flat.reshape(num_blocks, blocksize)
rescaled_blocks = dequantized_blocks * absmax.unsqueeze(1).to(dtype)
rescaled_flat = rescaled_blocks.flatten()
if pad_len > 0:
rescaled_flat = rescaled_flat[:-pad_len]
return rescaled_flat.reshape(A.shape)
@torch.compile
def _quantize_blockwise_pytorch(
A: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pure PyTorch reference implementation for block-wise quantization.
"""
if A.numel() == 0:
return torch.empty_like(A, dtype=torch.uint8), torch.empty(0, dtype=torch.float32, device=A.device)
A_flat = A.flatten()
num_elements = A_flat.numel()
num_blocks = math.ceil(num_elements / blocksize)
pad_len = num_blocks * blocksize - num_elements
if pad_len > 0:
A_flat = torch.nn.functional.pad(A_flat, (0, pad_len))
A_blocks = A_flat.reshape(num_blocks, blocksize)
absmax = torch.max(torch.abs(A_blocks), dim=1, keepdim=True)[0]
absmax[absmax == 0] = 1.0
scaled_blocks = A_blocks / absmax
# Inefficient but straightforward quantization, takes a lot of memory
diff = torch.abs(scaled_blocks.unsqueeze(2) - code.to(A.device))
quantized_indices = torch.argmin(diff, dim=2).to(torch.uint8)
quantized_flat = quantized_indices.flatten()
if pad_len > 0:
quantized_flat = quantized_flat[:-pad_len]
return quantized_flat.reshape(A.shape), absmax.flatten()
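# Hedged usage sketch for the two pure-PyTorch reference functions above. They are
# device-agnostic, so this runs on CPU; the uniform code table built with torch.linspace
# is only a stand-in for the dynamic quantization map bitsandbytes normally provides.
def _pytorch_reference_roundtrip_example():
    code = torch.linspace(-1.0, 1.0, 256)
    A = torch.randn(1000)  # deliberately not a multiple of the blocksize
    q, absmax = _quantize_blockwise_pytorch(A, code, blocksize=256)
    A_dq = _dequantize_blockwise_pytorch(q, absmax, code, blocksize=256, dtype=torch.float32)
    assert A_dq.shape == A.shape and q.dtype == torch.uint8
    return (A - A_dq).abs().max()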
# Main updated function
def optimizer_update_8bit_blockwise_pytorch(
p: torch.Tensor,
g: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,  # AdEMAMix
alpha: float,  # AdEMAMix
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float,
gnorm_scale: float,
skip_zeros: bool,
# AdEMAMix
*,
optimizer_name: str,
) -> None:
"""
Pure PyTorch implementation of the 8-bit block-wise optimizer update step.
This version ensures high-precision updates for float16 parameters.
"""
if skip_zeros:
raise ValueError("skip_zeros is not supported on XPU yet.")
blocksize = 256
with torch.no_grad():
# Dequantize states to perform updates in 32-bit precision
if optimizer_name == "ademamix" and absmax1.ndim == 2:
# For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked.
s1_1_fp32 = _dequantize_blockwise_pytorch(state1[0], absmax1[0], qmap1, blocksize, torch.float32)
s1_2_fp32 = _dequantize_blockwise_pytorch(state1[1], absmax1[1], qmap1, blocksize, torch.float32)
state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32])
else:
state1_fp32 = _dequantize_blockwise_pytorch(state1, absmax1, qmap1, blocksize, torch.float32)
state2_fp32 = None
if state2 is not None:
state2_fp32 = _dequantize_blockwise_pytorch(state2, absmax2, qmap2, blocksize, torch.float32)
grad = g.float() * gnorm_scale
# Create a 32-bit copy of the parameter for high-precision updates
p_fp32 = p.data.float()
if optimizer_name == "adam":
state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
bias_correction1 = 1.0 - beta1**step
bias_correction2 = 1.0 - beta2**step
denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps)
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1)
elif optimizer_name == "ademamix":
m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1]
nu_fp32 = state2_fp32
m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3)
nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
bias_correction1 = 1.0 - beta1**step
bias_correction2 = math.sqrt(1.0 - beta2**step)
update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps)
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
p_fp32.add_(update, alpha=-lr)
state1_fp32 = torch.stack([m1_fp32, m2_fp32])
elif optimizer_name == "momentum":
grad.add_(p_fp32, alpha=weight_decay)
if step == 1:
state1_fp32.copy_(grad)
else:
state1_fp32.mul_(beta1).add_(grad)
p_fp32.add_(state1_fp32, alpha=-lr)
elif optimizer_name == "rmsprop":
grad.add_(p_fp32, alpha=weight_decay)
state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1)
p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
elif optimizer_name == "lion":
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1))
p_fp32.add_(update_dir, alpha=-lr)
state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2)
elif optimizer_name == "adagrad":
grad.add_(p_fp32, alpha=weight_decay)
state1_fp32.addcmul_(grad, grad, value=1.0)
p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
else:
raise NotImplementedError(
f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available."
)
# Copy the updated 32-bit parameter back to the original tensor
p.data.copy_(p_fp32)
# Re-quantize states and update state tensors in-place
if optimizer_name == "ademamix":
new_m1_8bit, new_absmax_m1 = _quantize_blockwise_pytorch(state1_fp32[0], qmap1, blocksize)
new_m2_8bit, new_absmax_m2 = _quantize_blockwise_pytorch(state1_fp32[1], qmap1, blocksize)
state1[0].copy_(new_m1_8bit)
state1[1].copy_(new_m2_8bit)
absmax1[0].copy_(new_absmax_m1)
absmax1[1].copy_(new_absmax_m2)
new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize)
state2.copy_(new_state2_8bit)
absmax2.copy_(new_absmax2)
else:
new_state1_8bit, new_absmax1 = _quantize_blockwise_pytorch(state1_fp32, qmap1, blocksize)
state1.copy_(new_state1_8bit)
absmax1.copy_(new_absmax1)
if state2_fp32 is not None:
new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize)
state2.copy_(new_state2_8bit)
absmax2.copy_(new_absmax2)
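# Hedged example (illustration only): a single Adam step with the pure-PyTorch reference
# above. The uniform code table stands in for bitsandbytes' signed dynamic map, and the
# 8-bit state tensors are seeded by quantizing zeros with the reference quantizer;
# blocksize 256 matches the constant hard-coded in the update function.
def _adam_step_reference_example():
    torch.manual_seed(0)
    p = torch.randn(4096)
    g = torch.randn(4096)
    qmap = torch.linspace(-1.0, 1.0, 256)
    state1, absmax1 = _quantize_blockwise_pytorch(torch.zeros_like(p), qmap, 256)
    state2, absmax2 = _quantize_blockwise_pytorch(torch.zeros_like(p), qmap, 256)
    optimizer_update_8bit_blockwise_pytorch(
        p, g, state1, state2,
        beta1=0.9, beta2=0.999, beta3=0.0, alpha=0.0,
        eps=1e-8, step=1, lr=1e-3,
        qmap1=qmap, qmap2=qmap, absmax1=absmax1, absmax2=absmax2,
        weight_decay=0.0, gnorm_scale=1.0, skip_zeros=False,
        optimizer_name="adam",
    )
    return p  # updated in place; the states and absmax tensors were re-quantized in place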
#######################################
# Mixed torch + triton implementation #
#######################################
# Much more memory efficient due to using triton for quantization/dequantization
def optimizer_update_8bit_blockwise_triton_quant(
p: torch.Tensor,
g: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,  # AdEMAMix
alpha: float,  # AdEMAMix
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float,
gnorm_scale: float,
skip_zeros: bool,
# AdEMAMix
*,
optimizer_name: str,
) -> None:
"""
Pure PyTorch implementation of the 8-bit block-wise optimizer update step.
This version ensures high-precision updates for float16 parameters.
"""
if skip_zeros and not torch.any(g):
return
blocksize = 256
grad = g.float() * gnorm_scale
with torch.no_grad():
# Create a 32-bit copy of the parameter for high-precision updates
p_fp32 = p.data.float()
# Dequantize states to perform updates in 32-bit precision
if optimizer_name == "ademamix" and absmax1.ndim == 2:
# For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked.
s1_1_fp32 = dequant_8bit_blockwise(state1[0], absmax1[0], qmap1, blocksize, dtype=torch.float32)
s1_2_fp32 = dequant_8bit_blockwise(state1[1], absmax1[1], qmap1, blocksize, dtype=torch.float32)
state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32])
else:
state1_fp32 = dequant_8bit_blockwise(state1, absmax1, qmap1, blocksize, dtype=torch.float32)
state2_fp32 = None
if state2 is not None:
state2_fp32 = dequant_8bit_blockwise(state2, absmax2, qmap2, blocksize, dtype=torch.float32)
# Apply optimizer-specific update logic
if optimizer_name == "adam":
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
bias_correction1 = 1.0 - beta1**step
bias_correction2 = 1.0 - beta2**step
denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps)
p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1)
elif optimizer_name == "ademamix":
m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1]
nu_fp32 = state2_fp32
m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3)
nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
bias_correction1 = 1.0 - beta1**step
bias_correction2 = math.sqrt(1.0 - beta2**step)
update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps)
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
p_fp32.add_(update, alpha=-lr)
state1_fp32 = torch.stack([m1_fp32, m2_fp32])
elif optimizer_name == "momentum":
grad.add_(p_fp32, alpha=weight_decay)
if step == 1:
state1_fp32.copy_(grad)
else:
state1_fp32.mul_(beta1).add_(grad)
p_fp32.add_(state1_fp32, alpha=-lr)
elif optimizer_name == "rmsprop":
grad.add_(p_fp32, alpha=weight_decay)
state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1)
p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
elif optimizer_name == "lion":
if weight_decay > 0.0:
p_fp32.mul_(1.0 - lr * weight_decay)
update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1))
p_fp32.add_(update_dir, alpha=-lr)
state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2)
elif optimizer_name == "adagrad":
grad.add_(p_fp32, alpha=weight_decay)
state1_fp32.addcmul_(grad, grad, value=1.0)
p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
else:
raise NotImplementedError(
f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available."
)
# Copy the updated 32-bit parameter back to the original tensor
p.data.copy_(p_fp32)
# Re-quantize states and update state tensors in-place
if optimizer_name == "ademamix":
new_m1_8bit, new_absmax_m1 = quantize_blockwise_triton(state1_fp32[0], qmap1, blocksize)
new_m2_8bit, new_absmax_m2 = quantize_blockwise_triton(state1_fp32[1], qmap1, blocksize)
state1[0].copy_(new_m1_8bit)
state1[1].copy_(new_m2_8bit)
absmax1[0].copy_(new_absmax_m1)
absmax1[1].copy_(new_absmax_m2)
new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize)
state2.copy_(new_state2_8bit)
absmax2.copy_(new_absmax2)
else:
new_state1_8bit, new_absmax1 = quantize_blockwise_triton(state1_fp32, qmap1, blocksize)
state1.copy_(new_state1_8bit)
absmax1.copy_(new_absmax1)
if state2_fp32 is not None:
new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize)
state2.copy_(new_state2_8bit)
absmax2.copy_(new_absmax2)
#########################
# Triton implementation #
#########################
@triton.jit
def _optimizer_update_1state_8bit_blockwise_triton_kernel(
# Tensors
p_ptr,
g_ptr,
state1_ptr,
state2_ptr,
beta1: tl.constexpr,
beta2: tl.constexpr,
beta3,
alpha,
eps: tl.constexpr,
step,
beta1_step,
beta2_step,
lr,
qmap1_ptr,
qmap2_ptr,
absmax1_ptr,
absmax2_ptr,
weight_decay,
gnorm_scale,
# Meta-parameters
n_elements,
BLOCK_SIZE_N: tl.constexpr,
N_PER_TH: tl.constexpr,
OPTIMIZER_ID: tl.constexpr,
):
"""
Triton kernel for 8-bit optimizers that use one momentum state.
Supports: Momentum, RMSprop, Adagrad, Lion.
"""
# 1. Boilerplate: pid, offsets, mask
pid = tl.program_id(axis=0)
block_start_idx = pid * N_PER_TH
offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH)
mask = offsets < n_elements
# 2. Load and dequantize tensors
g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale
p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
# 3. Optimizer-specific updates
# Lion uses decoupled weight decay, applied to the parameter directly
if weight_decay > 0.0 and OPTIMIZER_ID == 4:
p *= 1.0 - lr * weight_decay
# Momentum, RMSprop and Adagrad apply weight decay to the gradient
elif weight_decay > 0.0:
g += p * weight_decay
# Momentum update
if OPTIMIZER_ID == 0: # MOMENTUM
if step == 1:
s1 = g
else:
s1 = s1 * beta1 + g
p -= lr * s1
# RMSprop update
elif OPTIMIZER_ID == 1: # RMSPROP
s1 = s1 * beta1 + (1.0 - beta1) * g * g
p -= lr * (g / (tl.sqrt(s1) + eps))
# Adagrad update
elif OPTIMIZER_ID == 2: # ADAGRAD
s1 += g * g
p -= lr * (g / (tl.sqrt(s1) + eps))
# Lion update
elif OPTIMIZER_ID == 4: # LION
val = s1 * beta1 + (1.0 - beta1) * g
update = tl.where(val > 0.0, 1.0, tl.where(val < 0.0, -1.0, 0.0))
p -= lr * update
s1 = s1 * beta2 + (1.0 - beta2) * g
# 4. Store updated parameter and requantized state
tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state1_ptr + offsets, s1_codes, mask=mask)
tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1)
@triton.jit
def _optimizer_update_2state_8bit_blockwise_triton_kernel(
# Tensors
p_ptr,
g_ptr,
state1_ptr,
state2_ptr,
beta1: tl.constexpr,
beta2: tl.constexpr,
# beta3 and alpha are only used by AdEMAMix
beta3,
alpha,
eps: tl.constexpr,
step,
beta1_step,
beta2_step,
lr,
qmap1_ptr,
qmap2_ptr,
absmax1_ptr,
absmax2_ptr,
weight_decay: tl.constexpr,
gnorm_scale: tl.constexpr,
# Meta-parameters
n_elements,
BLOCK_SIZE_N: tl.constexpr,
N_PER_TH: tl.constexpr,
OPTIMIZER_ID: tl.constexpr,
):
"""
Triton kernel for 8-bit optimizers that use two momentum states.
Supports: Adam, AdEMAMix.
"""
# 1. Boilerplate: pid, offsets, mask
pid = tl.program_id(axis=0)
block_start_idx = pid * N_PER_TH
offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH)
mask = offsets < n_elements
# 2. Load and dequantize tensors
g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale
p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
# 3. Optimizer-specific updates
if OPTIMIZER_ID == 3: # ADAM
s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
s2 = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
s1 = s1 * beta1 + (1.0 - beta1) * g
s2 = s2 * beta2 + (1.0 - beta2) * g * g
# In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error.
# For backwards compatibility we precompute the bias correction factors.
# bias_correction1 = 1.0 - libdevice.pow(beta1, step)
# bias_correction2 = 1.0 - libdevice.pow(beta2, step)
bias_correction1 = 1.0 - beta1_step
bias_correction2 = 1.0 - beta2_step
if weight_decay > 0.0:
p *= 1.0 - lr * weight_decay
denom = tl.sqrt(s2) / tl.sqrt(bias_correction2) + eps
p -= (lr / bias_correction1) * (s1 / denom)
# Store updated parameter
tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
# Requantize and store states
s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state1_ptr + offsets, s1_codes, mask=mask)
tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1)
s2_codes, new_absmax2 = quantize_8bit_blockwise_kernel_util(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state2_ptr + offsets, s2_codes, mask=mask)
tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax2)
elif OPTIMIZER_ID == 5: # ADEMAMIX
# AdEMAMix has a stacked state1 (m1, m2) and state2 (nu)
m1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
m2 = dequant_8bit_blockwise_kernel_util(
state1_ptr + n_elements,
offsets,
qmap1_ptr,
absmax1_ptr + n_elements // BLOCK_SIZE_N,
mask,
BLOCK_SIZE_N,
)
nu = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
m1 = m1 * beta1 + (1.0 - beta1) * g
m2 = m2 * beta3 + (1.0 - beta3) * g
nu = nu * beta2 + (1.0 - beta2) * g * g
# In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error.
# For backwards compatibility we precompute the bias correction factors.
# bias_correction1 = 1.0 - libdevice.pow(beta1, step)
# bias_correction2 = tl.sqrt(1.0 - libdevice.pow(beta2, step))
bias_correction1 = 1.0 - beta1_step
bias_correction2 = tl.sqrt(1.0 - beta2_step)
update = (m1 / bias_correction1 + alpha * m2) / (tl.sqrt(nu) / bias_correction2 + eps)
if weight_decay > 0.0:
p *= 1.0 - lr * weight_decay
p -= lr * update
# Store updated parameter
tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
# Requantize and store all three states
m1_codes, new_absmax_m1 = quantize_8bit_blockwise_kernel_util(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state1_ptr + offsets, m1_codes, mask=mask)
tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_m1)
m2_codes, new_absmax_m2 = quantize_8bit_blockwise_kernel_util(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state1_ptr + n_elements + offsets, m2_codes, mask=mask)
tl.store(
absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N,
new_absmax_m2,
)
nu_codes, new_absmax_nu = quantize_8bit_blockwise_kernel_util(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
tl.store(state2_ptr + offsets, nu_codes, mask=mask)
tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_nu)
name2optimizer_fn = {
"momentum": _optimizer_update_1state_8bit_blockwise_triton_kernel,
"rmsprop": _optimizer_update_1state_8bit_blockwise_triton_kernel,
"adagrad": _optimizer_update_1state_8bit_blockwise_triton_kernel,
"adam": _optimizer_update_2state_8bit_blockwise_triton_kernel,
"lion": _optimizer_update_1state_8bit_blockwise_triton_kernel,
"ademamix": _optimizer_update_2state_8bit_blockwise_triton_kernel,
}
def optimizer_update_8bit_blockwise_impl(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
if skip_zeros:
raise NotImplementedError("skip_zeros is not supported on XPU yet")
if optimizer_name == "ademamix":
# Handle AdEMAMix's stacked state tensors
if state1.dim() < 2 or state1.shape[0] != 2:
raise ValueError(
f"For ademamix, state1 must be a stacked tensor of shape (2, ...), but got {state1.shape}"
)
if absmax1.dim() < 2 or absmax1.shape[0] != 2:
raise ValueError(
f"For ademamix, absmax1 must be a stacked tensor of shape (2, ...), but got {absmax1.shape}"
)
BLOCK_SIZE = 256
N_PER_TH = 1 # Number of blocks processed per thread.
grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),)
fn = name2optimizer_fn[optimizer_name]
optimizer_id = name2optimizer_id[optimizer_name]
# In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error.
# For backwards compatibility we precompute the bias correction factors.
beta1_step = beta1**step
beta2_step = beta2**step
fn[grid](
p,
g,
state1,
state2,
beta1,
beta2,
beta3,
alpha,
eps,
step,
beta1_step,
beta2_step,
lr,
qmap1,
qmap2,
absmax1,
absmax2,
weight_decay,
gnorm_scale,
p.numel(),
BLOCK_SIZE_N=BLOCK_SIZE,
N_PER_TH=N_PER_TH,
OPTIMIZER_ID=optimizer_id,
num_warps=2,
)
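# Hedged usage sketch (requires a Triton-capable device such as "xpu"): one fused Adam
# step through the dispatcher above. Zero-initialised uint8 states with zero absmax
# dequantize to exactly zero, and the linspace code table is only a placeholder for the
# real quantization map; shapes follow the one-absmax-entry-per-256-element-block layout.
def _fused_adam_step_example(device: str = "xpu", n: int = 1 << 20):
    p = torch.randn(n, device=device)
    g = torch.randn(n, device=device)
    qmap = torch.linspace(-1.0, 1.0, 256, device=device, dtype=torch.float32)
    state1 = torch.zeros(n, dtype=torch.uint8, device=device)
    state2 = torch.zeros(n, dtype=torch.uint8, device=device)
    absmax1 = torch.zeros(n // 256, dtype=torch.float32, device=device)
    absmax2 = torch.zeros(n // 256, dtype=torch.float32, device=device)
    optimizer_update_8bit_blockwise_impl(
        "adam", g, p, state1, state2,
        beta1=0.9, beta2=0.999, beta3=0.0, alpha=0.0, eps=1e-8, step=1, lr=1e-3,
        qmap1=qmap, qmap2=qmap, absmax1=absmax1, absmax2=absmax2,
    )
    return p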
# Alternative implementations, kept for reference; the fused Triton kernel defined above
# (optimizer_update_8bit_blockwise_impl) is used by default.
# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_pytorch
# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_pytorch)
# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_quant
# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_triton_quant)
@@ -3,7 +3,7 @@ from typing import Optional
import torch
- from . import kernels_optim, triton_kernels
+ from . import kernels_4bit, kernels_8bit_quant, kernels_optim
# currently codes unused, kept for reference
# Should be the same for quant/dequant
@@ -17,19 +17,9 @@ torch_accelerator_module = getattr(torch, device_type, torch.cuda)
def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
# torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
- n = A.numel()
- blocks = -(n // -blocksize)
- absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
- out = torch.empty_like(A.flatten(), dtype=torch.uint8)
with torch_accelerator_module.device(A.device):
- triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
+ out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize)
+ return out, absmax.float()
- out = out.reshape(A.shape)
- return out, absmax.float()
def dequantize_blockwise(
@@ -38,22 +28,24 @@ def dequantize_blockwise(
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
# torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
- out = torch.empty_like(A, dtype=dtype, device=A.device)
with torch_accelerator_module.device(A.device):
- triton_kernels.dequant_int8_blockwise(
+ out = kernels_8bit_quant.dequant_8bit_blockwise(
A,
- code,
absmax,
- out,
+ code,
blocksize,
+ dtype=dtype,
)
return out
def dequantize_blockwise_inplace(
- A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+ dtype: torch.dtype,
+ out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
@@ -62,12 +54,13 @@ def dequantize_blockwise_inplace(
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
with torch_accelerator_module.device(A.device):
- triton_kernels.dequant_int8_blockwise(
+ kernels_8bit_quant.dequant_8bit_blockwise(
A,
- code,
absmax,
- out,
+ code,
blocksize,
+ dtype=dtype,
+ out=out,
)
@@ -92,7 +85,7 @@ def quantize_4bit(
out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
with torch_accelerator_module.device(A.device):
- triton_kernels.quantize_4bit_blockwise_triton(
+ kernels_4bit.quantize_4bit_blockwise_triton(
A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
)
packed = out
@@ -126,9 +119,8 @@ def dequantize_4bit(
A = A.squeeze().view(torch.uint8).unsqueeze(1)
out = torch.empty(shape, dtype=dtype, device=A.device)
with torch_accelerator_module.device(A.device):
- triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+ kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out
@@ -145,7 +137,7 @@ def dequantize_4bit_inplace(
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
with torch_accelerator_module.device(A.device):
- triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+ kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
def gemv_4bit(
@@ -162,7 +154,7 @@ def gemv_4bit(
B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
with torch_accelerator_module.device(A.device):
- triton_kernels._dequantize_4bit_impl_passing_code(
+ kernels_4bit.dequantize_4bit_impl_passing_code(
B,
absmax,
blocksize,
@@ -171,11 +163,95 @@ def gemv_4bit(
out=B_dq_triton,
)
return torch.nn.functional.linear(
A,
B_dq_triton,
bias=None,
)
# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_pytorch
# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_pytorch) # 60ms
# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_triton_quant #2.8ms
# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_triton_quant) # 2.3ms
optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_impl # ~0.95ms for adam
def optimizer_update_8bit_blockwise(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
# torch._check(
# g.numel() == p.numel(),
# lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
# )
# compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# torch._check(
# g.dtype in compute_dtypes,
# lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
# )
# torch._check(
# g.dtype == p.dtype,
# lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
# )
# torch._check(
# state1.dtype == torch.uint8,
# lambda: f"state1 must be uint8, got {state1.dtype}",
# )
# torch._check(
# qmap1.dtype == absmax1.dtype == torch.float32,
# lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
# )
# if state2 is not None:
# torch._check(
# state2.dtype == torch.uint8,
# lambda: f"state2 must be uint8, got {state2.dtype}",
# )
# torch._check(
# qmap2.dtype == absmax2.dtype == torch.float32,
# lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
# )
with torch_accelerator_module.device(state1.device):
optimizer_update_8bit_blockwise_impl(
optimizer_name=optimizer_name,
g=g,
p=p,
state1=state1,
state2=state2,
beta1=beta1,
beta2=beta2,
beta3=beta3,
alpha=alpha,
eps=eps,
step=step,
lr=lr,
qmap1=qmap1,
qmap2=qmap2,
absmax1=absmax1,
absmax2=absmax2,
weight_decay=weight_decay,
gnorm_scale=gnorm_scale,
skip_zeros=skip_zeros,
)
def optimizer_update_32bit(
...
@@ -156,6 +156,10 @@ if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
+ register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(
+     triton_ops.optimizer_update_8bit_blockwise
+ )
+ register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)
@register_kernel("bitsandbytes::dequantize_4bit", "xpu")
def _(
@@ -232,6 +236,7 @@ elif triton_available:
register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
+ register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(triton_ops.optimizer_update_8bit_blockwise)
register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)
else:
logger.warning("Register pytorch bitsandbytes kernels for XPU because no native library or triton packages found.")