Unverified commit a23026c8, authored by Dmitrii Makarenko and committed by GitHub

[Triton/XPU] Support 4bit dequantization logic on Triton (#1629)



* [xpu/triton] Add Triton dequantization kernel

This PR adds an XPU backend and a Triton kernel for dequantization of the nf4 dtype.
Triton is an optional import.
Tests:
	tests/test_functional.py::TestQuantize4BitFunctional (supported nf4/fp4 cases)
	tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional
	  (quantize_blockwise implemented with a binary search, which is faster on XPU; sketched below)
	tests/test_linear4bit.py
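
For reference, the binary-search lookup mentioned above can be written in plain PyTorch roughly as follows. This is only a sketch of the idea, not the Triton kernel itself; `quantize_block_ref` is an illustrative name.

```python
import torch

def quantize_block_ref(block: torch.Tensor, code: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Map one block of values to indices of the nearest entries in a sorted code table."""
    absmax = block.abs().max().clamp(min=1e-12)        # per-block scale, guard against all-zero blocks
    scaled = (block / absmax).clamp(code[0], code[-1])
    idx = torch.searchsorted(code, scaled)             # binary search for the insertion point
    idx = idx.clamp(1, code.numel() - 1)
    lo, hi = code[idx - 1], code[idx]
    idx = torch.where(scaled - lo <= hi - scaled, idx - 1, idx)  # pick the closer neighbor
    return idx.to(torch.uint8), absmax
```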
Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko@intel.com>

* align with ipex code

* enable test for ipex

* test_kbit_backprop: skip no longer needed

* remove unused

---------
Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko@intel.com>
parent d9333aa9
from collections.abc import Sequence

import torch

from . import triton_kernels

# Currently unused, kept for reference.
# Should be the same for quant/dequant
# from bitsandbytes.functional import get_4bit_type
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")


def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)  # ceil(n / blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
    out = out.reshape(A.shape)

    return out, absmax.float()


def dequantize_blockwise(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")

    out = torch.empty_like(A, dtype=dtype, device=A.device)

    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )

    return out
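
A minimal usage sketch for the two blockwise ops above, assuming an XPU device with Triton installed; the linear code table here is only for illustration (bitsandbytes normally uses its dynamic 8-bit map):

```python
import torch

A = torch.randn(4096, device="xpu", dtype=torch.float32)
code = torch.linspace(-1.0, 1.0, 256, device="xpu", dtype=torch.float32)

packed, absmax = quantize_blockwise(A, code, blocksize=256)  # uint8 codes + per-block absmax
A_hat = dequantize_blockwise(packed, absmax, code, blocksize=256, dtype=torch.float32)
print((A - A_hat).abs().max())  # small blockwise quantization error
```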

def dequantize_blockwise_inplace(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")

    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )


def quantize_4bit(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()

    # TODO: Support when weight matrix is not divisible by blocksize
    # torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

    blocks = -(n // -(blocksize * 2))  # ceil division

    absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
    out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)  # two 4-bit values per byte

    triton_kernels.quantize_4bit_blockwise_triton(
        A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
    )
    packed = out

    if quant_storage != torch.uint8:
        packed = out.squeeze().view(quant_storage).unsqueeze(1)

    return packed, absmax.float()
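
The quant_storage repacking at the end of quantize_4bit is a pure reinterpretation of the packed uint8 buffer. A small shape check (illustrative sizes only, runs on any device):

```python
import torch

n = 1024                                            # number of 4-bit values
out = torch.empty((n // 2, 1), dtype=torch.uint8)   # two 4-bit values per byte
packed = out.squeeze().view(torch.float16).unsqueeze(1)
print(out.shape, packed.shape)                      # torch.Size([512, 1]) torch.Size([256, 1])
```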

def dequantize_4bit(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    # torch._check(
    #     A.dtype == torch.uint8,
    #     lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}",
    # )

    # TODO: check whether this reinterpretation is fine and fast
    if A.dtype != torch.uint8:
        A = A.squeeze().view(torch.uint8).unsqueeze(1)

    out = torch.empty(shape, dtype=dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out
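
A minimal nf4 round trip through the two 4-bit ops above (assumes an XPU device with Triton installed; sizes are illustrative):

```python
import torch

W = torch.randn(1024, 1024, device="xpu", dtype=torch.float16)
packed, absmax = quantize_4bit(W, blocksize=64, quant_type="nf4", quant_storage=torch.uint8)
W_hat = dequantize_4bit(packed, absmax, blocksize=64, quant_type="nf4", shape=W.shape, dtype=W.dtype)
print((W - W_hat).abs().max())  # reconstruction error from 4-bit blockwise quantization
```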

def dequantize_4bit_inplace(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def gemv_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    if B.dtype != torch.uint8:
        B = B.squeeze().view(torch.uint8).unsqueeze(1)

    # Dequantize B to A's dtype, then fall back to a regular matmul.
    B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl_passing_code(
        B,
        absmax,
        blocksize,
        code,
        dtype=A.dtype,
        out=B_dq_triton,
    )

    return torch.nn.functional.linear(
        A,
        B_dq_triton,
        bias=None,
    )
(One file's diff is collapsed and not shown here.)
@@ -13,6 +13,15 @@ except BaseException:
     ipex_cpu = None
     ipex_xpu = None
+
+try:
+    import triton  # noqa: F401
+    import triton.language as tl  # noqa: F401
+
+    triton_available = True
+except ImportError as e:
+    triton_available = False
+
 _NF4_QUANT_TABLE = torch.tensor(
     [
         -1.0,
...
 from collections.abc import Sequence
+import warnings
 
 import torch
 
 from ..._ops import register_kernel
-from ..utils import ipex_xpu
+from ..utils import ipex_xpu, triton_available
 
-if torch.__version__ >= (2, 7):
+# _int_mm is available in torch starting from version 2.7,
+# but it currently has no XPU implementation.
+if ipex_xpu and torch.__version__ >= (2, 7):
 
     @register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
     def _(A: torch.Tensor, B: torch.Tensor):
@@ -15,6 +18,7 @@ if torch.__version__ >= (2, 7):
         ).reshape(*A.shape[:-1], B.shape[0])
 
+# IPEX should be faster on XPU, so first check whether it is available.
 if ipex_xpu:
 
     @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")
@@ -49,3 +53,15 @@ if ipex_xpu:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
 
         return out.reshape(shape)
+
+elif triton_available:
+    from ..triton import ops as triton_ops
+
+    register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
+    register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(triton_ops.dequantize_blockwise_inplace)
+    register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(triton_ops.dequantize_blockwise)
+    register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
+    register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
+    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
+    register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
+else:
+    warnings.warn("XPU available but no ipex or triton packages found.")
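
Once registered, the Triton implementations are reached through the regular dispatcher entry points, for example (assuming an XPU build where Triton is installed but IPEX is not):

```python
import torch
import bitsandbytes  # importing registers the XPU kernels

W = torch.randn(256, 256, device="xpu", dtype=torch.float16)
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(W, 64, "nf4", torch.uint8)
W_hat = torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, 64, "nf4", W.shape, W.dtype)
```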
@@ -671,7 +671,7 @@ class Int8Params(torch.nn.Parameter):
         if device is not None and device.type != "meta" and self.data.device.type == "cpu":
             if device.type != "cpu" or self.data.dtype != torch.int8:
                 return self._quantize(device)
-            elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu"):
+            elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu):
                 self.CB = self.data
 
         new_param = Int8Params(
...
@@ -137,11 +137,11 @@ class Test8BitBlockwiseQuantizeFunctional:
         abserr = sum(diffs) / len(diffs)
         relerr = sum(reldiffs) / len(reldiffs)
         if signed:
-            threshold_abserr = 0.0036 if device in ("cpu", "xpu") else 0.0035
+            threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035
             assert abserr < 0.0036
             assert relerr < 0.015
         else:
-            assert abserr < 0.00175 if device in ("cpu", "xpu") else 0.0023
+            assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023
             assert relerr < 0.012
         assert A2.dtype == dtype
@@ -172,7 +172,7 @@ class Test8BitBlockwiseQuantizeFunctional:
     @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
     @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"])
     def test_few_bit_quant(self, device, bits, method):
-        if device in ("cpu", "xpu") and bits != 8:
+        if device in ("cpu", "xpu") and bits != 8 and (F.ipex_cpu or F.ipex_xpu):
             pytest.skip("CPU/XPU implementation only supports 8 bits")
         abserrs = []
...
@@ -4,6 +4,7 @@ import pytest
 import torch
 
 import bitsandbytes
+from bitsandbytes.functional import ipex_xpu
 from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter
 
 # torch.library.opcheck is only available in torch 2.4 and later.
@@ -144,7 +145,7 @@ class TestInt8BlockwiseQuantOps:
         assert out.device == A.device
 
         # TODO: Enable it
-        if device == "xpu":
+        if device == "xpu" and ipex_xpu:
             pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check")
 
         opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
@@ -170,7 +171,7 @@ class Test4bitBlockwiseQuantOps:
         if storage_dtype != torch.uint8:
             pytest.xfail("opcheck fails for storage_dtype != torch.uint8")
 
-        opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
+        opcheck(torch.ops.bitsandbytes.quantize_4bit.default, (A, blocksize, quant_type, storage_dtype))
 
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
...