Unverified commit 404e2776 authored by Egor, committed by GitHub

[XPU] Implemented 8bit optimizers in triton (#1692)

* Implemented 8bit optimizers

* Add interface

* Commented out torch checks

* Merged

* Updated kernels

* Reused code for quant/dequant

* Removed empty line

* Changed README
parent 4b025748
@@ -141,7 +141,7 @@ bitsandbytes has the following minimum requirements for all platforms:
</td>
<td></td>
<td></td>
-<td>🚧</td>
+<td>〰️</td>
</tr>
<tr>
<td colspan="6">🍎 <strong>macOS 14+</strong></td>
...
@@ -4,167 +4,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
a_ptr,
c_ptr,
quant_ptr,
absmax_ptr,
num_paired_elements,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < num_paired_elements
a = tl.load(a_ptr + offsets, mask)
a = a.to(tl.uint8)
# apply conversion
scaled_int8 = tl.load(quant_ptr + a, mask)
abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK
abs_offsets = offsets // QUANT_BLOCK
mask_blocked = offsets < abs_blocks_lim
absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked)
# apply scales
out_dq = scaled_int8 * absmax
offs = block_start + tl.arange(0, SPLIT_SIZE)
mask = offs < num_paired_elements
tl.store(c_ptr + offs, out_dq, mask)
def dequant_int8_blockwise(
A_nf4: torch.Tensor,
quant_state_code: torch.Tensor,
absmax: torch.Tensor,
out: torch.Tensor,
quant_blocksize: int = 64,
):
number_of_paired_elements = A_nf4.numel()
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
dequant_8bit_kernel[grid](
A_nf4,
out,
quant_state_code,
absmax,
number_of_paired_elements,
quant_blocksize,
SPLIT_SIZE,
)
return out
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_blockwise_kernel(
A_ptr,
code_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
CODE_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
# To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE))
# Calculating absamax for each block
absmax = tl.max(tl.abs(A_reshaped), axis=1)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
A_normalized = A_reshaped / absmax[:, None]
A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter
pivot = (lower_pivot + upper_pivot) // 2
val = tl.load(code_ptr + pivot)
is_higher = A_normalized > val # code[pivot]
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
# Choose closest level
lower_val = tl.load(code_ptr + lower_pivot)
upper_val = tl.load(code_ptr + upper_pivot)
lower_dist = tl.abs(A_normalized - lower_val)
upper_dist = tl.abs(A_normalized - upper_val)
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
# too slow approach
# diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
# quantized = tl.argmin(diff, axis=2).to(tl.uint8)
quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
tl.store(out_ptr + offsets, quantized_flat, mask=mask)
def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out):
n = A.numel()
split_num_blocks = 1
grid = (triton.cdiv(blocks, split_num_blocks),)
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
quantize_blockwise_kernel[grid](
A_ptr=A,
code_ptr=code,
absmax_ptr=absmax,
out_ptr=quantized_out,
n_elements=n,
BLOCK_SIZE=blocksize,
CODE_SIZE=code.numel(),
SPLIT_NUM_BLOCKS=split_num_blocks,
)
return quantized_out, absmax
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4
# @triton.autotune(
# configs=[

@@ -587,7 +426,7 @@ def dequant_nf4_kernel(
    tl.store(c_ptr + offs, out_dq, mask)

-def _dequantize_4bit_impl(
+def dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,

@@ -611,7 +450,7 @@ def _dequantize_4bit_impl(
    dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)

-def _dequantize_4bit_impl_passing_code(
+def dequantize_4bit_impl_passing_code(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
...
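The new 8-bit quantization module added below (imported as kernels_8bit_quant in triton_ops.py) reimplements the blockwise scheme removed above: each block of values is scaled by its absolute maximum and mapped to the nearest entry of a 256-value code book. For orientation, here is a minimal PyTorch reference of that scheme. It is an illustrative sketch, not part of this commit; the helper names and the zero-division clamp are assumptions.

import torch

def quantize_blockwise_ref(A: torch.Tensor, code: torch.Tensor, blocksize: int):
    # Reference only: materializes a (blocks, blocksize, code_size) tensor, so use small inputs.
    n = A.numel()
    blocks = -(n // -blocksize)  # ceil division, as in the kernels
    padded = torch.zeros(blocks * blocksize, dtype=A.dtype, device=A.device)
    padded[:n] = A.flatten()
    padded = padded.view(blocks, blocksize)
    absmax = padded.abs().amax(dim=1)                        # one scale per block
    normalized = (padded / absmax[:, None].clamp(min=1e-12)).clamp(-1.0, 1.0)
    # Nearest code entry; the Triton kernel does this with an 8-step binary search instead.
    idx = (normalized[..., None] - code).abs().argmin(dim=-1)
    return idx.to(torch.uint8).flatten()[:n], absmax

def dequantize_blockwise_ref(q: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int):
    # Look up the code value for each byte, then rescale by the per-block absmax.
    scaled = code[q.long()]
    scales = absmax.repeat_interleave(blocksize)[: q.numel()]
    return scaled * scales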
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
a_ptr,
out_ptr,
code_ptr,
absmax_ptr,
n,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < n
out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK)
tl.store(out_ptr + offsets, out_dq, mask)
def dequant_8bit_blockwise(
a: torch.Tensor,
absmax: torch.Tensor,
quant_state_code: torch.Tensor,
quant_blocksize: int = 64,
dtype: torch.dtype = None,
out: torch.Tensor = None,
):
n = a.numel()
if out is None:
if dtype is None:
raise ValueError("If out is None, dtype must be specified")
out = torch.empty_like(a, dtype=dtype, device=a.device)
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(n, META["SPLIT_SIZE"]),)
grid = (triton.cdiv(n, SPLIT_SIZE),)
dequant_8bit_kernel[grid](
a,
out,
quant_state_code,
absmax,
n,
quant_blocksize,
SPLIT_SIZE,
)
return out
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_8bit_blockwise_kernel(
A_ptr,
code_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
CODE_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
tl.store(out_ptr + offsets, quantized, mask=mask)
def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
n = A.numel()
blocks = -(n // -blocksize)
if absmax is None:
absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
if out is None:
out = torch.empty_like(A.flatten(), dtype=torch.uint8)
split_num_blocks = 1
grid = (triton.cdiv(blocks, split_num_blocks),)
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
quantize_8bit_blockwise_kernel[grid](
A_ptr=A,
code_ptr=code,
absmax_ptr=absmax,
out_ptr=out,
n_elements=n,
BLOCK_SIZE=blocksize,
CODE_SIZE=code.numel(),
SPLIT_NUM_BLOCKS=split_num_blocks,
# num_warps=1,
# num_stages=2,
)
out = out.reshape(A.shape)
return out, absmax
@triton.jit
def quantize_8bit_blockwise_kernel_util(
a,
code_ptr,
CODE_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_PER_TH: tl.constexpr,
):
# Reshape so several blocks can be processed at once -> (N_PER_TH, BLOCK_SIZE)
a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE))
# Calculating absmax for each block
absmax = tl.max(tl.abs(a_reshaped), axis=1)
a_normalized = a_reshaped / absmax[:, None]
a_normalized = tl.clamp(a_normalized, -1.0, 1.0)
lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32)
upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
# 8 iterations = ceil(log2(CODE_SIZE)) for the 256-entry code; in the general case this should be a kernel parameter
for _ in range(8):
pivot = (lower_pivot + upper_pivot) // 2
val = tl.load(code_ptr + pivot)
is_higher = a_normalized > val # code[pivot]
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
# Choose closest level
lower_val = tl.load(code_ptr + lower_pivot)
upper_val = tl.load(code_ptr + upper_pivot)
lower_dist = tl.abs(a_normalized - lower_val)
upper_dist = tl.abs(a_normalized - upper_val)
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
# Brute-force alternative (too slow):
# diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
# quantized = tl.argmin(diff, axis=2).to(tl.uint8)
quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,))
return quantized_flat, absmax
@triton.jit
def dequant_8bit_blockwise_kernel_util(
a_ptr,
offsets,
code_ptr,
absmax_ptr,
mask,
BLOCK_SIZE: tl.constexpr,
):
a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8)
scaled_int8 = tl.load(code_ptr + a, mask)
# Load scales
absmax_offsets = offsets // BLOCK_SIZE
absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last")
# Apply scales
out_dq = scaled_int8 * absmax
return out_dq
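A quick round trip through the helpers above, as a sketch rather than part of the commit: it assumes an XPU-enabled Triton build, that the 256-entry code comes from bitsandbytes.functional.create_dynamic_map() (how the rest of the library builds it), and that the import path matches the `from . import kernels_8bit_quant` seen in triton_ops.py below.

import torch
from bitsandbytes.functional import create_dynamic_map
from bitsandbytes.backends.triton import kernels_8bit_quant  # path assumed from this commit's layout

code = create_dynamic_map(signed=True).to("xpu")      # 256-entry quantization map
A = torch.randn(4096, device="xpu", dtype=torch.float32)

q, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize=256)
A_dq = kernels_8bit_quant.dequant_8bit_blockwise(q, absmax, code, quant_blocksize=256, dtype=A.dtype)
print((A - A_dq).abs().max())                         # expect a small quantization error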
This diff is collapsed.
@@ -3,7 +3,7 @@ from typing import Optional
import torch

-from . import kernels_optim, triton_kernels
+from . import kernels_4bit, kernels_8bit_quant, kernels_optim

# currently codes unused, kept for reference
# Should be the same for quant/dequant

@@ -17,19 +17,9 @@ torch_accelerator_module = getattr(torch, device_type, torch.cuda)
def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
-    n = A.numel()
-    blocks = -(n // -blocksize)
-    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
-    out = torch.empty_like(A.flatten(), dtype=torch.uint8)
    with torch_accelerator_module.device(A.device):
-        triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
+        out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize)
+    return out, absmax.float()
-    out = out.reshape(A.shape)
-    return out, absmax.float()

def dequantize_blockwise(

@@ -38,22 +28,24 @@ def dequantize_blockwise(
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
-    out = torch.empty_like(A, dtype=dtype, device=A.device)
    with torch_accelerator_module.device(A.device):
-        triton_kernels.dequant_int8_blockwise(
+        out = kernels_8bit_quant.dequant_8bit_blockwise(
            A,
-            code,
            absmax,
-            out,
+            code,
            blocksize,
+            dtype=dtype,
        )
    return out

def dequantize_blockwise_inplace(
-    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+    dtype: torch.dtype,
+    out: torch.Tensor,
) -> None:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

@@ -62,12 +54,13 @@ def dequantize_blockwise_inplace(
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    with torch_accelerator_module.device(A.device):
-        triton_kernels.dequant_int8_blockwise(
+        kernels_8bit_quant.dequant_8bit_blockwise(
            A,
-            code,
            absmax,
-            out,
+            code,
            blocksize,
+            dtype=dtype,
+            out=out,
        )

@@ -92,7 +85,7 @@ def quantize_4bit(
    out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
    with torch_accelerator_module.device(A.device):
-        triton_kernels.quantize_4bit_blockwise_triton(
+        kernels_4bit.quantize_4bit_blockwise_triton(
            A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
        )
    packed = out

@@ -126,9 +119,8 @@ def dequantize_4bit(
        A = A.squeeze().view(torch.uint8).unsqueeze(1)
    out = torch.empty(shape, dtype=dtype, device=A.device)
    with torch_accelerator_module.device(A.device):
-        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+        kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out

@@ -145,7 +137,7 @@ def dequantize_4bit_inplace(
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    with torch_accelerator_module.device(A.device):
-        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+        kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)

def gemv_4bit(

@@ -162,7 +154,7 @@ def gemv_4bit(
    B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
    with torch_accelerator_module.device(A.device):
-        triton_kernels._dequantize_4bit_impl_passing_code(
+        kernels_4bit.dequantize_4bit_impl_passing_code(
            B,
            absmax,
            blocksize,

@@ -171,11 +163,95 @@ def gemv_4bit(
            out=B_dq_triton,
        )
    return torch.nn.functional.linear(
        A,
        B_dq_triton,
        bias=None,
    )
# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_pytorch
# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_pytorch) # 60ms
# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_triton_quant #2.8ms
# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_triton_quant) # 2.3ms
optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_impl # ~0.95ms for adam
def optimizer_update_8bit_blockwise(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
# torch._check(
# g.numel() == p.numel(),
# lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
# )
# compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# torch._check(
# g.dtype in compute_dtypes,
# lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
# )
# torch._check(
# g.dtype == p.dtype,
# lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
# )
# torch._check(
# state1.dtype == torch.uint8,
# lambda: f"state1 must be uint8, got {state1.dtype}",
# )
# torch._check(
# qmap1.dtype == absmax1.dtype == torch.float32,
# lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
# )
# if state2 is not None:
# torch._check(
# state2.dtype == torch.uint8,
# lambda: f"state2 must be uint8, got {state2.dtype}",
# )
# torch._check(
# qmap2.dtype == absmax2.dtype == torch.float32,
# lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
# )
with torch_accelerator_module.device(state1.device):
optimizer_update_8bit_blockwise_impl(
optimizer_name=optimizer_name,
g=g,
p=p,
state1=state1,
state2=state2,
beta1=beta1,
beta2=beta2,
beta3=beta3,
alpha=alpha,
eps=eps,
step=step,
lr=lr,
qmap1=qmap1,
qmap2=qmap2,
absmax1=absmax1,
absmax2=absmax2,
weight_decay=weight_decay,
gnorm_scale=gnorm_scale,
skip_zeros=skip_zeros,
)
def optimizer_update_32bit(
...
@@ -156,6 +156,10 @@ if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
    register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
    register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
+    register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(
+        triton_ops.optimizer_update_8bit_blockwise
+    )
+    register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)

    @register_kernel("bitsandbytes::dequantize_4bit", "xpu")
    def _(

@@ -232,6 +236,7 @@ elif triton_available:
    register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
    register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
+    register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(triton_ops.optimizer_update_8bit_blockwise)
    register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)
else:
    logger.warning("Register pytorch bitsandbytes kernels for XPU because no native library or triton packages found.")
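With these registrations, bitsandbytes::optimizer_update_8bit_blockwise on XPU dispatches to the Triton implementation, so the stock 8-bit optimizers should work end to end. A minimal sketch, not part of the commit: it assumes a PyTorch build with XPU support plus Triton for XPU, and that the parameter tensor is large enough for 8-bit state (small tensors fall back to 32-bit state by default).

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).to("xpu")
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

x = torch.randn(16, 1024, device="xpu")
model(x).sum().backward()
opt.step()        # runs the Triton optimizer_update_8bit_blockwise kernel registered above
opt.zero_grad()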