Unverified Commit a23026c8 authored by Dmitrii Makarenko's avatar Dmitrii Makarenko Committed by GitHub
Browse files

[Triton/XPU] Support 4bit dequantization logic on Triton (#1629)



* [xpu/triton] Add trtion dequantization kernel

This PR adds xpu backend and trtion kernel for dequantization nf4 dtype.
Trtion is an optional import.
Tests:
	tests/test_functional.py::TestQuantize4BitFunctional supported nf4/fp4 cases
	tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional
implemented quantize_blockwise with binary search that works faster for XPU
        tests/test_linear4bit.py
Signed-off-by: default avatarDmitrii Makarenko <dmitrii.makarenko@intel.com>

* align with ipex code

* enable test for ipex

* test_kbit_backprop: skip no longer needed

* remove unused

---------
Signed-off-by: default avatarDmitrii Makarenko <dmitrii.makarenko@intel.com>
parent d9333aa9
from collections.abc import Sequence
import torch
from . import triton_kernels
# currently codes unused, kept for reference
# Should be the same for quant/dequant
# from bitsandbytes.functional import get_4bit_type
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
# torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
out = torch.empty_like(A.flatten(), dtype=torch.uint8)
triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
out = out.reshape(A.shape)
return out, absmax.float()
def dequantize_blockwise(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
# torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
out = torch.empty_like(A, dtype=dtype, device=A.device)
triton_kernels.dequant_int8_blockwise(
A,
code,
absmax,
out,
blocksize,
)
return out
def dequantize_blockwise_inplace(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
triton_kernels.dequant_int8_blockwise(
A,
code,
absmax,
out,
blocksize,
)
def quantize_4bit(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
# torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)
n = A.numel()
# TODO: Support when weight matrix is not divisible by blocksize
# torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
blocks = -(n // -(blocksize * 2))
absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
triton_kernels.quantize_4bit_blockwise_triton(
A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
)
packed = out
if quant_storage != torch.uint8:
packed = out.squeeze().view(quant_storage).unsqueeze(1)
return packed, absmax.float()
def dequantize_4bit(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
# torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
# torch._check(
# A.dtype == torch.uint8,
# lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}",
# )
# Check if this is fine and fast
if A.dtype != torch.uint8:
A = A.squeeze().view(torch.uint8).unsqueeze(1)
out = torch.empty(shape, dtype=dtype, device=A.device)
triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out
def dequantize_4bit_inplace(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
def gemv_4bit(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
if B.dtype != torch.uint8:
B = B.squeeze().view(torch.uint8).unsqueeze(1)
B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
triton_kernels._dequantize_4bit_impl_passing_code(
B,
absmax,
blocksize,
code,
dtype=A.dtype,
out=B_dq_triton,
)
return torch.nn.functional.linear(
A,
B_dq_triton,
bias=None,
)
import torch
import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_SIZE': 64}),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128}),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_SIZE": 256}),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# triton.Config({"SPLIT_SIZE": 512}),
# # triton.Config({'SPLIT_SIZE': 1024}),
# ],
# key=["num_paired_elements", "QUANT_BLOCK"],
# )
@triton.jit
def dequant_8bit_kernel(
a_ptr,
c_ptr,
quant_ptr,
absmax_ptr,
num_paired_elements,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < num_paired_elements
a = tl.load(a_ptr + offsets, mask)
a = a.to(tl.uint8)
# apply conversion
scaled_int8 = tl.load(quant_ptr + a, mask)
abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK
abs_offsets = offsets // QUANT_BLOCK
mask_blocked = offsets < abs_blocks_lim
absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked)
# apply scales
out_dq = scaled_int8 * absmax
offs = block_start + tl.arange(0, SPLIT_SIZE)
mask = offs < num_paired_elements
tl.store(c_ptr + offs, out_dq, mask)
def dequant_int8_blockwise(
A_nf4: torch.Tensor,
quant_state_code: torch.Tensor,
absmax: torch.Tensor,
out: torch.Tensor,
quant_blocksize: int = 64,
):
number_of_paired_elements = A_nf4.numel()
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
dequant_8bit_kernel[grid](
A_nf4,
out,
quant_state_code,
absmax,
number_of_paired_elements,
quant_blocksize,
SPLIT_SIZE,
)
return out
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_blockwise_kernel(
A_ptr,
code_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
CODE_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
# To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE))
# Calculating absamax for each block
absmax = tl.max(tl.abs(A_reshaped), axis=1)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
A_normalized = A_reshaped / absmax[:, None]
A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter
pivot = (lower_pivot + upper_pivot) // 2
val = tl.load(code_ptr + pivot)
is_higher = A_normalized > val # code[pivot]
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
# Choose closest level
lower_val = tl.load(code_ptr + lower_pivot)
upper_val = tl.load(code_ptr + upper_pivot)
lower_dist = tl.abs(A_normalized - lower_val)
upper_dist = tl.abs(A_normalized - upper_val)
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
# too slow approach
# diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
# quantized = tl.argmin(diff, axis=2).to(tl.uint8)
quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
tl.store(out_ptr + offsets, quantized_flat, mask=mask)
def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out):
n = A.numel()
split_num_blocks = 1
grid = (triton.cdiv(blocks, split_num_blocks),)
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
quantize_blockwise_kernel[grid](
A_ptr=A,
code_ptr=code,
absmax_ptr=absmax,
out_ptr=quantized_out,
n_elements=n,
BLOCK_SIZE=blocksize,
CODE_SIZE=code.numel(),
SPLIT_NUM_BLOCKS=split_num_blocks,
)
return quantized_out, absmax
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# triton.Config({"SPLIT_NUM_BLOCKS": 4}),
# triton.Config({"SPLIT_NUM_BLOCKS": 8}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_fp4_blockwise_kernel(
A_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
PAIRED_SPLIT_NUM_BLOCKS: tl.constexpr = SPLIT_NUM_BLOCKS * 2
block_start_idx = tl.program_id(0) * PAIRED_SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
# To be able process several blocks -> (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE)
A_reshaped = tl.reshape(A, (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE))
# Calculating absamax for each block
absmax = tl.max(tl.abs(A_reshaped), axis=1)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS), absmax)
A_normalized = A_reshaped / absmax[:, None]
A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
sign = tl.where(A_normalized < 0, 0b1000, 0b0000)
A_absf = tl.abs(A_normalized)
result = tl.where(
A_absf > 0.29166667,
tl.where(
A_absf > 0.583333, tl.where(A_absf > 0.8333333, 0b011, 0b010), tl.where(A_absf > 0.4166667, 0b101, 0b100)
),
tl.where(
A_absf > 0.0859375,
tl.where(A_absf > 0.20833333, 0b0111, 0b0110),
tl.where(A_absf > 0.00260417, 0b0001, 0b0000),
),
)
quantized = (result ^ sign).to(tl.uint8)
quantized = quantized.reshape((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE // 2, 2))
left, right = quantized.split()
packed = left << 4 | (right & 0xF)
packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
out_mask = out_offsets < n_elements // 2
tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeNF4
# @triton.autotune(
# configs=[
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# triton.Config({"SPLIT_NUM_BLOCKS": 4}),
# triton.Config({"SPLIT_NUM_BLOCKS": 8}),
# ],
# key=["n_elements"],
# )
@triton.jit
def quantize_nf4_blockwise_kernel(
A_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
PAIRED_SPLIT_NUM_BLOCKS: tl.constexpr = SPLIT_NUM_BLOCKS * 2
block_start_idx = tl.program_id(0) * PAIRED_SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
# To be able process several blocks -> (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE)
A_reshaped = tl.reshape(A, (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE))
# Calculating absamax for each block
absmax = tl.max(tl.abs(A_reshaped), axis=1)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS), absmax)
A_normalized = A_reshaped / absmax[:, None]
A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
result = tl.where(
A_normalized > 0.03979014977812767,
tl.where(
A_normalized > 0.3893125355243683,
tl.where(
A_normalized > 0.6427869200706482,
tl.where(A_normalized > 0.8614784181118011, 0b1111, 0b1110),
tl.where(A_normalized > 0.5016634166240692, 0b1101, 0b1100),
),
tl.where(
A_normalized > 0.2035212516784668,
tl.where(A_normalized > 0.2920137718319893, 0b1011, 0b1010),
tl.where(A_normalized > 0.1202552504837513, 0b1001, 0b1000),
),
),
tl.where(
A_normalized > -0.33967943489551544,
tl.where(
A_normalized > -0.13791173323988914,
tl.where(A_normalized > -0.045525018125772476, 0b0111, 0b0110),
tl.where(A_normalized > -0.23460740596055984, 0b0101, 0b0100),
),
tl.where(
A_normalized > -0.6106329262256622,
tl.where(A_normalized > -0.4599952697753906, 0b0011, 0b0010),
tl.where(A_normalized > -0.8480964004993439, 0b0001, 0b0000),
),
),
)
quantized = result.to(tl.uint8)
quantized = quantized.reshape((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE // 2, 2))
left, right = quantized.split()
packed = left << 4 | (right & 0xF)
packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
out_mask = out_offsets < n_elements // 2
tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
def quantize_4bit_blockwise_triton(A, blocksize, quant_type, blocks, absmax, num_elements, quantized_out):
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
split_num_blocks = 4
grid = (triton.cdiv(blocks, split_num_blocks),)
if quant_type == "fp4":
quantize_fp4_blockwise_kernel[grid](
A_ptr=A,
absmax_ptr=absmax,
out_ptr=quantized_out,
n_elements=num_elements,
BLOCK_SIZE=blocksize,
SPLIT_NUM_BLOCKS=split_num_blocks,
)
else:
quantize_nf4_blockwise_kernel[grid](
A_ptr=A,
absmax_ptr=absmax,
out_ptr=quantized_out,
n_elements=num_elements,
BLOCK_SIZE=blocksize,
SPLIT_NUM_BLOCKS=split_num_blocks,
)
return quantized_out, absmax
@triton.jit
def dequant_4bit_body_util(a, offsets, quant_ptr, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr):
PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2
mask = offsets < n_elems
higher = a & 0xF
# lower 4bits
lower = a >> 4
abs_offsets = offsets // PAIRED_QUANT_BLOCK
absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=1.0, eviction_policy="evict_last")
# apply conversion
lower_4 = tl.load(quant_ptr + lower, eviction_policy="evict_last")
higher_4 = tl.load(quant_ptr + higher, eviction_policy="evict_last")
mul_high = higher_4 * absmax
mul_low = lower_4 * absmax
out_dq = tl.interleave(mul_low, mul_high)
return out_dq
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dDequantizeFP4Tree
@triton.jit
def dequantize_fp4_tree(val, absmax):
# val: tl.tensor (uint8)
# absmax: tl.tensor (float32/float16)
# 00001100 00001011 00001001 00001111
sign = tl.where((val & 0b1000) == 0b1000, -1.0, 1.0) # -1
third_bit = (val & 0b0100) == 0b0100 # True
second_bit = (val & 0b0010) == 0b0010 # False
first_bit = (val & 0b0001) == 0b0001 # False
branch1 = tl.where(
second_bit,
tl.where(first_bit, 0.25, 0.16666667), # 1111, 1110
tl.where(first_bit, 0.5, 0.33333333), # 1101, 1100
)
branch2 = tl.where(
second_bit,
tl.where(first_bit, 1.0, 0.66666667), # 1011, 1010
tl.where(first_bit, 0.00520833, 0.0), # 1001, 1000
)
out = tl.where(third_bit, branch1, branch2)
return out * sign * absmax
@triton.jit
def dequant_fp4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr):
PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2
mask = offsets < n_elems
higher = a & 0xF
lower = a >> 4
abs_offsets = offsets // PAIRED_QUANT_BLOCK
absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=1.0, eviction_policy="evict_last")
mul_high = dequantize_fp4_tree(higher, absmax)
mul_low = dequantize_fp4_tree(lower, absmax)
out_dq = tl.interleave(mul_low, mul_high)
return out_dq
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dDequantizeNF4
@triton.jit
def dequantize_nf4_tree(val):
# val: tl.tensor (uint8)
cond0 = (val & 0b1000) == 0b1000
cond1 = (val & 0b0100) == 0b0100
cond2 = (val & 0b0010) == 0b0010
cond3 = (val & 0b0001) == 0b0001
# Positive branch (val & 0b1000) == 8
branch_pos = tl.where(
cond1,
tl.where(
cond2,
tl.where(cond3, 1.0, 0.7229568362236023), # 1111, 1110
tl.where(cond3, 0.5626170039176941, 0.44070982933044434), # 1101, 1100
),
tl.where(
cond2,
tl.where(cond3, 0.33791524171829224, 0.24611230194568634), # 1011, 1010
tl.where(cond3, 0.16093020141124725, 0.07958029955625534), # 1001, 1000
),
)
# Negative branch (val & 0b1000) == 0
branch_neg = tl.where(
cond1,
tl.where(
cond2,
tl.where(cond3, 0.0, -0.09105003625154495), # 0111, 0110
tl.where(cond3, -0.18477343022823334, -0.28444138169288635), # 0101, 0100
),
tl.where(
cond2,
tl.where(cond3, -0.39491748809814453, -0.5250730514526367), # 0011, 0010
tl.where(cond3, -0.6961928009986877, -1.0), # 0001, 0000
),
)
return tl.where(cond0, branch_pos, branch_neg)
@triton.jit
def dequant_nf4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.constexpr):
PAIRED_QUANT_BLOCK: tl.constexpr = QUANT_BLOCK // 2
mask = offsets < n_elems
higher = a & 0xF
# lower 4bits
lower = a >> 4
abs_offsets = offsets // PAIRED_QUANT_BLOCK
absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=1.0, eviction_policy="evict_last")
mul_high = dequantize_nf4_tree(higher) * absmax
mul_low = dequantize_nf4_tree(lower) * absmax
out_dq = tl.interleave(mul_low, mul_high)
return out_dq
# All such kernels are similar, so maybe code can be generalised.
# @triton.autotune(
# configs=[
# # # triton.Config({'SPLIT_SIZE': 64}),
# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({'SPLIT_SIZE': 128}),
# triton.Config({'SPLIT_SIZE': 128}, num_warps = 32, num_stages = 2),
# # # triton.Config({'SPLIT_SIZE': 128}, num_warps = 4, num_stages = 4),
# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# triton.Config({'SPLIT_SIZE': 256}),
# triton.Config({'SPLIT_SIZE': 256}, num_warps = 32, num_stages = 2),
# # triton.Config({'SPLIT_SIZE': 256}, num_warps = 4, num_stages = 4),
# triton.Config({'SPLIT_SIZE': 512}),
# triton.Config({'SPLIT_SIZE': 512}, num_warps = 32, num_stages = 2),
# # triton.Config({'SPLIT_SIZE': 512}, num_warps = 4, num_stages = 4),
# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # # # triton.Config({'SPLIT_SIZE': 512, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
# # # triton.Config({'SPLIT_SIZE': 1024}),
# # # # triton.Config({'SPLIT_SIZE': 2048}),
# # # # triton.Config({'SPLIT_SIZE': 4096}),
# # # # triton.Config({'SPLIT_SIZE': 8192}),
# # # # triton.Config({'SPLIT_SIZE': 16384}),
# ],
# key=['num_paired_elements'],
# )
@triton.jit
def dequant_4bit_kernel(
a_ptr, c_ptr, quant_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < num_paired_elements
a = tl.load(a_ptr + offsets, mask, eviction_policy="evict_first")
out_dq = dequant_4bit_body_util(
a=a,
offsets=offsets,
quant_ptr=quant_ptr,
absmax_ptr=absmax_ptr,
n_elems=num_paired_elements,
QUANT_BLOCK=QUANT_BLOCK,
)
out_block_start = pid * SPLIT_SIZE * 2
offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
mask = offs < num_paired_elements * 2
tl.store(c_ptr + offs, out_dq, mask)
# @triton.autotune(
# configs=[
# triton.Config({'SPLIT_SIZE': 128}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 256}),
# triton.Config({'SPLIT_SIZE': 256}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 512}),
# triton.Config({'SPLIT_SIZE': 512}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 1024}, num_warps = 32, num_stages = 2),
# ],
# key=['num_paired_elements'],
# )
@triton.jit
def dequant_fp4_kernel(
a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < num_paired_elements
a = tl.load(a_ptr + offsets, mask, eviction_policy="evict_first")
out_dq = dequant_fp4_body_util(
a=a,
offsets=offsets,
absmax_ptr=absmax_ptr,
n_elems=num_paired_elements,
QUANT_BLOCK=QUANT_BLOCK,
)
out_block_start = pid * SPLIT_SIZE * 2
offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
mask = offs < num_paired_elements * 2
tl.store(c_ptr + offs, out_dq, mask)
# @triton.autotune(
# configs=[
# triton.Config({'SPLIT_SIZE': 128}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 256}),
# triton.Config({'SPLIT_SIZE': 256}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 512}),
# triton.Config({'SPLIT_SIZE': 512}, num_warps = 32, num_stages = 2),
# triton.Config({'SPLIT_SIZE': 1024}, num_warps = 32, num_stages = 2),
# ],
# key=['num_paired_elements'],
# )
@triton.jit
def dequant_nf4_kernel(
a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
block_start = pid * SPLIT_SIZE
offsets = block_start + tl.arange(0, SPLIT_SIZE)
mask = offsets < num_paired_elements
a = tl.load(a_ptr + offsets, mask, eviction_policy="evict_first")
out_dq = dequant_nf4_body_util(
a=a,
offsets=offsets,
absmax_ptr=absmax_ptr,
n_elems=num_paired_elements,
QUANT_BLOCK=QUANT_BLOCK,
)
out_block_start = pid * SPLIT_SIZE * 2
offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
mask = offs < num_paired_elements * 2
tl.store(c_ptr + offs, out_dq, mask)
def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
# It's will be processed as an array, so
# actual length is row * col
# Elements are in uint8 format, so interleaved
# so total amount of data is 2 * elem_count
number_of_paired_elements = A.numel()
# we assume that split_size > quant_blocksize
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
if quant_type == "fp4":
dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
else:
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
def _dequantize_4bit_impl_passing_code(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
code: torch.Tensor,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
number_of_paired_elements = A.numel()
# we assume that split_size > quant_blocksize
SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
dequant_4bit_kernel[grid](A, out, code, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
######################### Fallback dequantization functions #########################
## for debug ##
# @triton.autotune(
# configs=[
# # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
# # triton.Config({'SPLIT_NUM_BLOCKS': 1, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # #
# # triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# #
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
# # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "large"}, num_stages=2, num_warps=32),
# # # triton.Config({'SPLIT_NUM_BLOCKS': 2, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
# # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=2, num_warps=32),
# # triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
# # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=2, num_warps=32),
# # triton.Config({"SPLIT_NUM_BLOCKS": 4, "grf_mode": "large"}, num_stages=4, num_warps=32),
# # triton.Config({'SPLIT_NUM_BLOCKS': 8, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
# ],
# key=["n_elements", "BLOCK_SIZE"],
# )
@triton.jit
def quantize_4bit_blockwise_kernel(
A_ptr,
code_ptr,
absmax_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
CODE_SIZE: tl.constexpr,
SPLIT_NUM_BLOCKS: tl.constexpr,
):
PAIRED_SPLIT_NUM_BLOCKS: tl.constexpr = SPLIT_NUM_BLOCKS * 2
block_start_idx = tl.program_id(0) * PAIRED_SPLIT_NUM_BLOCKS
thread_idx = tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS * BLOCK_SIZE)
offsets = block_start_idx * BLOCK_SIZE + thread_idx
mask = offsets < n_elements
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
# To be able process several blocks -> (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE)
A_reshaped = tl.reshape(A, (PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE))
# Calculating absamax for each block
absmax = tl.max(tl.abs(A_reshaped), axis=1)
tl.store(absmax_ptr + block_start_idx + tl.arange(0, PAIRED_SPLIT_NUM_BLOCKS), absmax)
A_normalized = A_reshaped / absmax[:, None]
A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
lower_pivot = tl.zeros((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
upper_pivot = tl.full((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
for _ in range(4): # ceil(log2(code_size)) = 4, actually, in general case should be input parameter
pivot = (lower_pivot + upper_pivot) // 2
val = tl.load(code_ptr + pivot)
is_higher = A_normalized > val # code[pivot]
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
# Choose closest level
lower_val = tl.load(code_ptr + lower_pivot)
upper_val = tl.load(code_ptr + upper_pivot)
lower_dist = tl.abs(A_normalized - lower_val)
upper_dist = tl.abs(A_normalized - upper_val)
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
quantized = quantized.reshape((PAIRED_SPLIT_NUM_BLOCKS, BLOCK_SIZE // 2, 2))
quantized = quantized.to(tl.uint8, bitcast=True)
left, right = quantized.split()
packed = left << 4 | (right & 0xF)
# Reduce don't guarantee the order of the elements passed to unite_2_int4
# packed = tl.reduce(quantized, axis=2, combine_fn=unite_2_int4)
# packed = packed.to(tl.uint8, bitcast=True)
packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
out_mask = out_offsets < n_elements // 2
tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)
...@@ -13,6 +13,15 @@ except BaseException: ...@@ -13,6 +13,15 @@ except BaseException:
ipex_cpu = None ipex_cpu = None
ipex_xpu = None ipex_xpu = None
try:
import triton # noqa: F401
import triton.language as tl # noqa: F401
triton_available = True
except ImportError as e:
triton_available = False
_NF4_QUANT_TABLE = torch.tensor( _NF4_QUANT_TABLE = torch.tensor(
[ [
-1.0, -1.0,
......
from collections.abc import Sequence from collections.abc import Sequence
import warnings
import torch import torch
from ..._ops import register_kernel from ..._ops import register_kernel
from ..utils import ipex_xpu from ..utils import ipex_xpu, triton_available
if torch.__version__ >= (2, 7): # _int_mm is available in torch starting from 2.7 version,
# but currently it's don't have xpu implementation.
if ipex_xpu and torch.__version__ >= (2, 7):
@register_kernel("bitsandbytes::int8_linear_matmul", "xpu") @register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
def _(A: torch.Tensor, B: torch.Tensor): def _(A: torch.Tensor, B: torch.Tensor):
...@@ -15,6 +18,7 @@ if torch.__version__ >= (2, 7): ...@@ -15,6 +18,7 @@ if torch.__version__ >= (2, 7):
).reshape(*A.shape[:-1], B.shape[0]) ).reshape(*A.shape[:-1], B.shape[0])
# IPEX should be faster for xpu, so at first checking if it is available.
if ipex_xpu: if ipex_xpu:
@register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu") @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")
...@@ -49,3 +53,15 @@ if ipex_xpu: ...@@ -49,3 +53,15 @@ if ipex_xpu:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
return out.reshape(shape) return out.reshape(shape)
elif triton_available:
from ..triton import ops as triton_ops
register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(triton_ops.dequantize_blockwise_inplace)
register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(triton_ops.dequantize_blockwise)
register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
else:
warnings.warn("XPU available but no ipex or triton packages found.")
...@@ -671,7 +671,7 @@ class Int8Params(torch.nn.Parameter): ...@@ -671,7 +671,7 @@ class Int8Params(torch.nn.Parameter):
if device is not None and device.type != "meta" and self.data.device.type == "cpu": if device is not None and device.type != "meta" and self.data.device.type == "cpu":
if device.type != "cpu" or self.data.dtype != torch.int8: if device.type != "cpu" or self.data.dtype != torch.int8:
return self._quantize(device) return self._quantize(device)
elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu"): elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu):
self.CB = self.data self.CB = self.data
new_param = Int8Params( new_param = Int8Params(
......
...@@ -137,11 +137,11 @@ class Test8BitBlockwiseQuantizeFunctional: ...@@ -137,11 +137,11 @@ class Test8BitBlockwiseQuantizeFunctional:
abserr = sum(diffs) / len(diffs) abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs) / len(reldiffs) relerr = sum(reldiffs) / len(reldiffs)
if signed: if signed:
threshold_abserr = 0.0036 if device in ("cpu", "xpu") else 0.0035 threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035
assert abserr < 0.0036 assert abserr < 0.0036
assert relerr < 0.015 assert relerr < 0.015
else: else:
assert abserr < 0.00175 if device in ("cpu", "xpu") else 0.0023 assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023
assert relerr < 0.012 assert relerr < 0.012
assert A2.dtype == dtype assert A2.dtype == dtype
...@@ -172,7 +172,7 @@ class Test8BitBlockwiseQuantizeFunctional: ...@@ -172,7 +172,7 @@ class Test8BitBlockwiseQuantizeFunctional:
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"]) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"])
def test_few_bit_quant(self, device, bits, method): def test_few_bit_quant(self, device, bits, method):
if device in ("cpu", "xpu") and bits != 8: if device in ("cpu", "xpu") and bits != 8 and (F.ipex_cpu or F.ipex_xpu):
pytest.skip("CPU/XPU implementation only supports 8 bits") pytest.skip("CPU/XPU implementation only supports 8 bits")
abserrs = [] abserrs = []
......
...@@ -4,6 +4,7 @@ import pytest ...@@ -4,6 +4,7 @@ import pytest
import torch import torch
import bitsandbytes import bitsandbytes
from bitsandbytes.functional import ipex_xpu
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter
# torch.library.opcheck is only available in torch 2.4 and later. # torch.library.opcheck is only available in torch 2.4 and later.
...@@ -144,7 +145,7 @@ class TestInt8BlockwiseQuantOps: ...@@ -144,7 +145,7 @@ class TestInt8BlockwiseQuantOps:
assert out.device == A.device assert out.device == A.device
# TODO: Enable it # TODO: Enable it
if device == "xpu": if device == "xpu" and ipex_xpu:
pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check") pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check")
opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype)) opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
...@@ -170,7 +171,7 @@ class Test4bitBlockwiseQuantOps: ...@@ -170,7 +171,7 @@ class Test4bitBlockwiseQuantOps:
if storage_dtype != torch.uint8: if storage_dtype != torch.uint8:
pytest.xfail("opcheck fails for storage_dtype != torch.uint8") pytest.xfail("opcheck fails for storage_dtype != torch.uint8")
opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) opcheck(torch.ops.bitsandbytes.quantize_4bit.default, (A, blocksize, quant_type, storage_dtype))
@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment