Unverified commit 1813b058, authored by Liu Xiaoli and committed by GitHub

Add SYCL Kernels for XPU backend (#1679)



* Add SYCL Kernels for XPU backend

* fix transpose
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix log and format
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert cpu changes
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* clean ipex_xpu
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* clean ipex import
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix ipex cpu import
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix typo
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comments
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* refine gemv_4bit kernel

* enable FP4 for dequant_4bit and gemv_4bit

* refine FP4 dequantization performance

* remove check for better performance
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix doc
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* clean code

* fix tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* rm comments
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix memory issue

* fix ut failure

* adjust threshold
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix xpu check
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* change test_functional check
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix test_module
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix device check
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Enable Windows build and refine code

* fix xpu log
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* remove ipex entirely
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cpu int8 CB
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix lint
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix logs (#12)

* fix logs
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix sycl lint error and tests (#13)

* fix sycl nd
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip typo check for xpu kernel codes (#14)

* skip test for xpu ops
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix lint
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip typo for xpu
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* register triton kernel for quantization (#15)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix version comparison issue (#18)

# Description

The version comparison expression is missing the `.release` property on the parsed version object on the right-hand side, which leads to comparing a tuple against a `Version` instance (a corrected sketch follows the error message below).

# Error message
```
The 8-bit optimizer is not available on your device, only available on CUDA for now.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Traceback (most recent call last):
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/unsloth_validation/run.py", line 1, in <module>
    import unsloth
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/__init__.py", line 235, in <module>
    from .models import *
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/__init__.py", line 15, in <module>
    from .llama     import FastLlamaModel
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/llama.py", line 23, in <module>
    from ._utils import *
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/_utils.py", line 89, in <module>
    from unsloth_zoo.patching_utils import (
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth_zoo/patching_utils.py", line 629, in <module>
    import transformers.integrations.bitsandbytes
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py", line 20, in <module>
    import bitsandbytes as bnb
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/bitsandbytes/bitsandbytes/__init__.py", line 39, in <module>
    from .backends.xpu import ops as xpu_ops
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/bitsandbytes/bitsandbytes/backends/xpu/ops.py", line 17, in <module>
    if version.parse(torch.__version__).release >= version.parse("2.9"):
TypeError: '>=' not supported between instances of 'tuple' and 'Version'
```
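
For reference, a minimal sketch of the corrected guard, mirroring the fix applied to `bitsandbytes/backends/xpu/ops.py` in this PR; taking `.release` on both sides keeps the comparison tuple-vs-tuple:

```python
import torch
from packaging import version

# Compare the release tuples on both sides so a tuple is never compared
# against a packaging.version.Version object.
if version.parse(torch.__version__).release >= version.parse("2.9").release:
    ...  # register the torch>=2.9-only XPU ops (e.g. int8_linear_matmul)
```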

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Er-Xin (Edwin) Shang <shangerxin@hotmail.com>
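
As a quick way to exercise the new backend end to end, here is a hypothetical smoke test (not part of this PR's test suite). It assumes a PyTorch build where `torch.xpu.is_available()` is true and a bitsandbytes build compiled with `COMPUTE_BACKEND=xpu`, and it only touches the existing public `bitsandbytes.functional` API:

```python
import torch
import bitsandbytes.functional as F

# Hypothetical smoke test: round-trip an NF4 blockwise quantization on an
# XPU tensor so the SYCL dequantize kernel added in this PR gets exercised.
assert torch.xpu.is_available(), "requires an Intel GPU visible to PyTorch"

x = torch.randn(64, 64, dtype=torch.float16, device="xpu")
packed, state = F.quantize_4bit(x, blocksize=64, quant_type="nf4")
restored = F.dequantize_4bit(packed, state)
print((restored - x).abs().max().item())
```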
parent 275671be
@@ -162,7 +162,7 @@ jobs:
- name: Run tests
run: pytest --durations=100
-test-cpu-ipex:
test-cpu-intel:
if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
needs: build-cpu
runs-on: banb-aws-general-8-plus-use1-public-80
@@ -186,7 +186,6 @@ jobs:
- name: Install dependencies
run: |
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu
-pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
pip install -e ".[test]"
pip install pytest-cov
@@ -196,9 +195,6 @@ jobs:
- name: Show environment information
run: python -m torch.utils.collect_env
-- name: IPEX smoke test
-run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);"
- name: Run tests
run: pytest --durations=100
@@ -286,15 +282,6 @@ jobs:
fail-fast: false
matrix:
torch_version: ["2.7.1"] #["2.6.0", "2.7.1"]
-ipex: [false]
-# ipex: [true, false]
-# include:
-# - torch_version: "2.6.0"
-# ipex: true
-# ipex_version: "2.6.10+xpu"
-# - torch_version: "2.7.1"
-# ipex: true
-# ipex_version: "2.7.10+xpu"
runs-on:
group: bandb-itac-bmsprpvc1550-8-1gpu
env:
@@ -330,10 +317,6 @@ jobs:
- name: Install PyTorch
run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu
-- name: Install IPEX
-if: matrix.ipex == true
-run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
- name: Install dependencies
run: |
pip install -e ".[test]"
......
@@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)

# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})
-set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)")
-set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
@@ -64,10 +65,18 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS ON)
elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
if(APPLE)
message(FATAL_ERROR "XPU is not supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU OFF)
endif()
@@ -217,6 +226,15 @@ elseif(BUILD_MPS)
COMMENT "Compiling Metal kernels"
VERBATIM)
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
elseif(BUILD_XPU)
list(APPEND SRC_FILES ${XPU_FILES})
string(APPEND BNB_OUTPUT_NAME "_xpu")
add_compile_definitions(BUILD_XPU)
set(CMAKE_C_COMPILER icx)
set(CMAKE_CXX_COMPILER icpx)
if(WIN32)
set(CMAKE_CXX_COMPILER icx)
endif()
else()
string(APPEND BNB_OUTPUT_NAME "_cpu")
set(GPU_SOURCES)
@@ -285,6 +303,15 @@ if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
endif()
if(BUILD_XPU)
set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")
set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20)
target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS})
target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS})
endif()

if(WIN32)
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
......
[files]
# Skip these files in typo checks
extend-exclude = [
"csrc/xpu_ops.h",
"csrc/xpu_ops.cpp",
"csrc/xpu_kernels.h",
"csrc/xpu_kernels.cpp"
]

[default]
extend-ignore-re = [
......
@@ -4,8 +4,6 @@ from typing import Optional
import torch

-from .cextension import ipex_cpu, ipex_xpu

_IS_TORCH_GTE_24 = False
if hasattr(torch.library, "register_fake"):
@@ -331,25 +329,6 @@ def _(
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")

-if ipex_cpu or ipex_xpu:
-# Register the dequantize_nf4_ipex implementation
-torch.library.define(
-"bitsandbytes::dequantize_nf4_ipex",
-"(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
-)
-
-@register_fake("bitsandbytes::dequantize_nf4_ipex")
-def _(
-A: torch.Tensor,
-absmax: torch.Tensor,
-blocksize: int,
-shape: Sequence[int],
-dtype: torch.dtype,
-) -> torch.Tensor:
-torch._check_is_size(blocksize)
-return torch.empty(shape, dtype=dtype, device=A.device)

torch.library.define(
"bitsandbytes::optimizer_update_32bit",
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
......
@@ -8,7 +8,6 @@ import torch
from typing_extensions import deprecated

import bitsandbytes.functional as F
-from bitsandbytes.functional import ipex_cpu

# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -320,8 +319,6 @@ class MatMul8bitFp(torch.autograd.Function):
CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
output = torch.nn.functional.linear(A, CB, bias)
-# to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu]
-state.idx = False
ctx.state = state
ctx.dtype_A = A.dtype
ctx.grad_shape = A.shape
@@ -426,7 +423,7 @@ def matmul(
state.threshold = threshold
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
if state.is_training:
-if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu"):
if A.device.type in ("cpu", "xpu"):
return MatMul8bitFp.apply(A, B, out, bias, state)
return MatMul8bitLt.apply(A, B, out, bias, state)
@@ -440,17 +437,6 @@ def matmul_4bit(
):
assert quant_state is not None

-if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
-if getattr(quant_state, "ipex", False):
-# IPEX CPU will change weight to 4D so don't need transpose
-B = B.t() if B.dim() == 2 else B
-out = F.gemv_4bit(A, B, out, state=quant_state)
-if bias is not None:
-out += bias
-return out
-else:
-return MatMul4Bit.apply(A, B, out, bias, quant_state)
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
......
-from collections.abc import Sequence
import ctypes as ct
import logging

import torch

from bitsandbytes.functional import get_ptr

from ..._ops import register_kernel
-from ...cextension import lib
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
-from ..utils import ipex_cpu

logger = logging.getLogger(__name__)

# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
@@ -24,8 +25,10 @@ if torch.__version__ >= (2, 6):
).reshape(*A.shape[:-1], B.shape[0])

if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):

@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
n = A.numel()
@@ -66,9 +69,10 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
return out, absmax

@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
-def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
def _(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
@@ -95,26 +99,3 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
out = out.reshape(A.shape)
return out

-if ipex_cpu:
-from bitsandbytes.utils import _reverse_4bit_compress_format
-
-@register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
-def _(
-A: torch.Tensor,
-absmax: torch.Tensor,
-blocksize: int,
-shape: Sequence[int],
-dtype: torch.dtype,
-) -> torch.Tensor:
-ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
-A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
-return torch.ops.bitsandbytes.dequantize_4bit.default(
-A,
-absmax,
-blocksize,
-"nf4",
-shape,
-dtype,
-)
@@ -3,16 +3,6 @@ import subprocess
from packaging import version
import torch

-try:
-# to support Intel CPU/XPU (IPEX) backend
-import intel_extension_for_pytorch as ipex
-ipex_cpu = ipex if ipex._C._has_cpu() else None
-ipex_xpu = ipex if ipex._C._has_xpu() else None
-except BaseException:
-ipex_cpu = None
-ipex_xpu = None

try:
import triton # noqa: F401
import triton.language as tl # noqa: F401
......
File mode changed from 100755 to 100644
from collections.abc import Sequence
-import warnings
import ctypes as ct
import logging

from packaging import version
import torch

from bitsandbytes.functional import _get_tensor_stream, get_ptr

from ..._ops import register_kernel
-from ..utils import ipex_xpu, triton_available
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
from ..utils import triton_available

logger = logging.getLogger(__name__)

-# _int_mm is available in torch starting from 2.9 version, or ipex 2.7
-if version.parse(torch.__version__).release >= version.parse("2.9").release or (
-ipex_xpu and torch.__version__ >= (2, 7)
-):
# _int_mm is available in torch starting from 2.9 version
if version.parse(torch.__version__).release >= version.parse("2.9").release:

@register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
def _(A: torch.Tensor, B: torch.Tensor):
@@ -20,42 +24,205 @@ if version.parse(torch.__version__).release >= version.parse("2.9").release or (
).reshape(*A.shape[:-1], B.shape[0])

-# IPEX should be faster for xpu, so at first checking if it is available.
-if ipex_xpu:

def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
args = (
None,
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(out.numel()),
_get_tensor_stream(A),
)
if dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(*args)
else:
lib.cdequantize_blockwise_bf16_nf4(*args)
elif dtype == torch.float16:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4(*args)
else:
lib.cdequantize_blockwise_fp16_nf4(*args)
elif dtype == torch.float32:
if quant_type == "fp4":
lib.cdequantize_blockwise_fp32_fp4(*args)
else:
lib.cdequantize_blockwise_fp32_nf4(*args)
def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
args = (
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(A.numel()),
_get_tensor_stream(A),
)
if dtype == torch.float16:
lib.cdequantize_blockwise_fp16(*args)
elif dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16(*args)
elif dtype == torch.float32:
lib.cdequantize_blockwise_fp32(*args)
def _gemv_4bit_impl(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
m = ct.c_int32(1)
n = ct.c_int32(shapeB[0])
k = ct.c_int32(shapeB[1])
lda = m
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
ldc = m
stream = _get_tensor_stream(A)
if A.dtype == torch.float16:
lib.cgemv_4bit_inference_fp16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.bfloat16:
lib.cgemv_4bit_inference_bf16(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
elif A.dtype == torch.float32:
lib.cgemv_4bit_inference_fp32(
m,
n,
k,
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(code),
get_ptr(out),
lda,
ldb,
ldc,
ct.c_int32(blocksize),
stream,
)
@register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu") # SYCL should be faster for xpu, so at first checking if it is available.
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
logger.info("Register sycl bitsandbytes kernels for XPU")
# TODO: Remove the triton register when quantization sycl kernel is ready.
if triton_available:
from ..triton import ops as triton_ops
register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
@register_kernel("bitsandbytes::dequantize_4bit", "xpu")
def _( def _(
A: torch.Tensor, A: torch.Tensor,
absmax: torch.Tensor, absmax: torch.Tensor,
blocksize: int, blocksize: int,
quant_type: str,
shape: Sequence[int], shape: Sequence[int],
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype) out = torch.empty(shape, dtype=dtype, device=A.device)
_dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out
@register_kernel("bitsandbytes::dequantize_blockwise", "xpu") @register_kernel("bitsandbytes::dequantize_blockwise", "xpu")
def _(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
out = torch.empty_like(A, dtype=dtype)
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
return out
@register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")
def _( def _(
A: torch.Tensor, A: torch.Tensor,
absmax: torch.Tensor, absmax: torch.Tensor,
code: torch.Tensor, code: torch.Tensor,
blocksize: int, blocksize: int,
dtype: torch.dtype, dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
@register_kernel("bitsandbytes::gemv_4bit", "xpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor: ) -> torch.Tensor:
shape = A.shape shape = (*A.shape[:-1], shapeB[0])
out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device) out = torch.empty(shape, device=A.device, dtype=A.dtype)
# void cdequantize_blockwise_fp32( _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
# float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) return out
if dtype == torch.float16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
elif dtype == torch.bfloat16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
elif dtype == torch.float32:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
return out.reshape(shape) @register_kernel("bitsandbytes::gemv_4bit.out", "xpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check(
out.shape == (*A.shape[:-1], shapeB[0]),
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
)
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
elif triton_available: elif triton_available:
logger.info("Register triton bitsandbytes kernels for XPU")
from ..triton import ops as triton_ops from ..triton import ops as triton_ops
register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise) register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
...@@ -67,4 +234,4 @@ elif triton_available: ...@@ -67,4 +234,4 @@ elif triton_available:
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit) register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)
else: else:
warnings.warn("XPU available but no ipex or triton packages found.") logger.warning("Register pytorch bitsandbytes kernels for XPU because no native library or triton packages found.")
@@ -283,6 +283,9 @@ def get_native_library() -> BNBNativeLibrary:
binary_path = cuda_binary_path

if torch._C._has_xpu:
binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}"

logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
# Try to load the library - any errors will propagate up
@@ -299,28 +302,25 @@
ROCM_GPU_ARCH = get_rocm_gpu_arch()

-try:
-# to support Intel CPU/GPU (XPU) backend
-import intel_extension_for_pytorch as ipex
-ipex_cpu = ipex if ipex._C._has_cpu() else None
-ipex_xpu = ipex if ipex._C._has_xpu() else None
-except BaseException:
-ipex_cpu = None
-ipex_xpu = None
HIP_ENVIRONMENT = False
BNB_BACKEND = "CPU"
if torch.version.hip:
HIP_ENVIRONMENT = True
BNB_BACKEND = "ROCm"
elif torch.cuda.is_available():
BNB_BACKEND = "CUDA"
elif torch._C._has_xpu:
BNB_BACKEND = "XPU"

try:
-if torch.version.hip:
-HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm"
-else:
-HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA"
lib = get_native_library()
except Exception as e:
if BNB_BACKEND in ("CPU", "XPU"):
lib = ErrorHandlerMockBNBNativeLibrary("XPU/CPU can run without native library.")
else:
error_msg = str(e)
-if not (ipex_cpu or ipex_xpu):
logger.error(
-f"bitsandbytes library load error: {error_msg}\n If you are using Intel CPU/XPU, please install intel_extension_for_pytorch to enable required ops",
f"bitsandbytes library load error: {error_msg}",
exc_info=True,
)
......
@@ -13,9 +13,9 @@ import torch
from torch import Tensor
from typing_extensions import deprecated

-from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

-from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib
from .cextension import HIP_ENVIRONMENT, lib

name2qmap = {}
@@ -370,6 +370,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
# We use the raw stream for performance reasons.
if tensor.device.type == "xpu":
return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index))
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
@@ -984,16 +986,6 @@ def dequantize_4bit(
if absmax.dtype != torch.float32:
absmax = absmax.float()

-# IPEX format is different, we need extra process.
-if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
-return torch.ops.bitsandbytes.dequantize_nf4_ipex(
-A,
-absmax,
-quant_state.blocksize,
-quant_state.shape,
-quant_state.dtype,
-)

if out is not None:
torch.ops.bitsandbytes.dequantize_4bit.out(
A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
@@ -1530,25 +1522,6 @@ def gemv_4bit(
if state.nested:
absmax = dequantize_blockwise(absmax, state.state2) + state.offset

-if getattr(state, "ipex", False) and state.quant_type == "nf4":
-# compute_dtype: 1 indicates fp16, 2 indicates bf16
-compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
-out = torch.ops.torch_ipex.woq_linear(
-A,
-B,
-"nf4",
-state.shape,
-state.new_scales,
-state.new_zeros,
-None,
-None,
-state.blocksize,
-compute_dtype,
-1,
-state.compensation,
-)
-return out

if out is not None:
torch.ops.bitsandbytes.gemv_4bit.out(
A,
@@ -2227,49 +2200,3 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
C = 127.0

-def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
-quant_state = linear.weight.quant_state
-if quant_state.nested:
-absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
-absmax += quant_state.offset
-if absmax.dtype != torch.float32:
-absmax = absmax.float()
-quant_state.absmax = absmax
-quant_state.nested = False
-delattr(quant_state, "state2")
-if x.device.type == "cpu" and ipex_cpu:
-converted_weight = _reverse_4bit_compress_format(linear.weight.data)
-new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
-converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
-"nf4",
-quant_state.shape, # weight shape
-quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
-None, # zero_points
-None, # bias
-None, # batch_size
-quant_state.blocksize,
-2,
-)
-elif x.device.type == "xpu" and ipex_xpu:
-new_weight = _reverse_4bit_compress_format(linear.weight.data)
-new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
-new_zeros = None
-compensation = None
-new_scales = list(new_scales)
-if not linear.training and not x.requires_grad:
-new_weight = new_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
-else:
-raise ValueError(
-"Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.7"
-)
-linear.weight.data = new_weight.data
-linear.weight.quant_state.ipex = True
-linear.weight.quant_state.new_scales = new_scales
-linear.weight.quant_state.new_zeros = new_zeros
-linear.weight.quant_state.compensation = compensation
@@ -12,13 +12,9 @@ import torch.nn.functional as F
import bitsandbytes as bnb
from bitsandbytes.cextension import HIP_ENVIRONMENT
-from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu
from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
-from bitsandbytes.utils import (
-INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
-OutlierTracer,
-_reverse_4bit_compress_format,
-)
from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer

T = TypeVar("T", bound="torch.nn.Module")
@@ -483,7 +479,6 @@ class Linear4bit(nn.Linear):
self.compute_type_is_set = compute_dtype is not None
self.quant_state = None
self.quant_storage = quant_storage
-self.ipex_linear_is_set = False

def set_compute_type(self, x):
if x.dtype in [torch.float32, torch.bfloat16]:
@@ -510,40 +505,13 @@
save weight and bias,
then fill state_dict with components of quant_state
"""
-if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False):
-if self.weight.device.type == "cpu":
-original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
-self.weight, "nf4", self.weight.quant_state.shape, 2
-)
-self.weight.data = _reverse_4bit_compress_format(original_weight.data)
-elif self.weight.device.type == "xpu":
-self.weight.data = _reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
-self.weight.quant_state.ipex = False
-self.ipex_linear_is_set = False

super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
if getattr(self.weight, "quant_state", None) is not None:
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()

-def set_ipex_linear(self, x: torch.Tensor):
-if (
-not getattr(self.weight.quant_state, "ipex", False)
-and self.weight.data.dtype == torch.uint8
-and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
-and self.weight.quant_state.quant_type == "nf4"
-):
-if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False):
-_enable_ipex_fusion(self, x)

def forward(self, x: torch.Tensor):
-# Check if ipex fusion can be used
-if not self.ipex_linear_is_set and (ipex_cpu or ipex_xpu):
-self.set_ipex_linear(x)
-self.ipex_linear_is_set = True
fix_4bit_weight_quant_state_from_module(self)

# weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -559,8 +527,7 @@
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-# IPEX CPU will change weight to 4D so don't need transpose
-weight = self.weight.t() if self.weight.dim() == 2 else self.weight
weight = self.weight.t()

return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
@@ -715,7 +682,7 @@ class Int8Params(torch.nn.Parameter):
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
if device.type != "cpu" or self.data.dtype != torch.int8:
return self._quantize(device)
-elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu):
elif self.data.dtype == torch.int8 and device.type == "cpu":
self.CB = self.data
new_param = Int8Params(
......
@@ -38,14 +38,6 @@ def outlier_hook(module, input):
hook.remove()

-# convert btw standard 4-bit compression format and ipex compression format
-def _reverse_4bit_compress_format(weight: torch.Tensor):
-out_1 = (weight & 0xF0) >> 4
-out_2 = (weight & 0xF) << 4
-out = out_1 | out_2
-return out

class OutlierTracer:
_instance = None
......
@@ -12,6 +12,9 @@
#if BUILD_MPS
// #include <mps_ops.h>
#endif
#if BUILD_XPU
#include <xpu_ops.h>
#endif

#include <cpu_ops.h>

// Compatibility between HIP/CUDA APIs
@@ -308,6 +311,90 @@ void spmm_coo_very_sparse_naive_int8(
}
#endif
#if BUILD_XPU
void dequantizeBlockwise_fp16(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise<sycl::half, General8bit>(code, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp16_fp4(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise<sycl::half, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp16_nf4(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise<sycl::half, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp32(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp32_fp4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_fp32_nf4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
sycl::queue* stream
) {
dequantizeBlockwise<sycl::ext::oneapi::bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16_fp4(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
sycl::queue* stream
) {
dequantizeBlockwise<sycl::ext::oneapi::bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream);
}
void dequantizeBlockwise_bf16_nf4(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
sycl::queue* stream
) {
dequantizeBlockwise<sycl::ext::oneapi::bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream);
}
void gemv_4bit_inference_fp16(
int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda,
int ldb, int ldc, int blocksize, sycl::queue* stream
) {
gemv_4bit_inference<sycl::half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void gemv_4bit_inference_bf16(
int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype,
sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
) {
gemv_4bit_inference<sycl::ext::oneapi::bfloat16, 16>(
m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream
);
}
void gemv_4bit_inference_fp32(
int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
int ldc, int blocksize, sycl::queue* stream
) {
gemv_4bit_inference<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
#endif
extern "C" { extern "C" {
#if BUILD_CUDA || BUILD_HIP #if BUILD_CUDA || BUILD_HIP
void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); } void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); }
...@@ -658,6 +745,88 @@ void cgemm_4bit_inference_naive_fp32( ...@@ -658,6 +745,88 @@ void cgemm_4bit_inference_naive_fp32(
#endif #endif
#if BUILD_XPU
void cdequantize_blockwise_fp16_fp4(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp16(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp16_nf4(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp32(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp32_fp4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_fp32_nf4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
) {
dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
sycl::queue* stream
) {
dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16_fp4(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
sycl::queue* stream
) {
dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream);
}
void cdequantize_blockwise_bf16_nf4(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
sycl::queue* stream
) {
dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream);
}
void cgemv_4bit_inference_fp16(
int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda,
int ldb, int ldc, int blocksize, sycl::queue* stream
) {
gemv_4bit_inference_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void cgemv_4bit_inference_bf16(
int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype,
sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
) {
gemv_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
void cgemv_4bit_inference_fp32(
int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
int ldc, int blocksize, sycl::queue* stream
) {
gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
}
#endif
void cquantize_blockwise_cpu_fp32(
float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n
) {
......
#include "xpu_kernels.h"
#include <bit>
#include <cmath>
#include <iostream>
#include <sycl/sycl.hpp>
inline float dDequantizeFP4(unsigned char val) {
if ((val & 0b1000) == 8)
if ((val & 0b0100) == 4)
if ((val & 0b0010) == 2)
if ((val & 0b0001) == 1)
return -0.25000000f;
else
return -0.16666667f;
else if ((val & 0b0001) == 1)
return -0.50000000f;
else
return -0.33333333f;
else if ((val & 0b0010) == 2)
if ((val & 0b0001) == 1)
return -1.00000000f;
else
return -0.66666667f;
else if ((val & 0b0001) == 1)
return -5.208333333e-03f;
else
return 0.00000000f;
else if ((val & 0b0100) == 4)
if ((val & 0b0010) == 2)
if ((val & 0b0001) == 1)
return 0.25000000f;
else
return 0.16666667f;
else if ((val & 0b0001) == 1)
return 0.50000000f;
else
return 0.33333333f;
else if ((val & 0b0010) == 2)
if ((val & 0b0001) == 1)
return 1.00000000f;
else
return 0.66666667f;
else if ((val & 0b0001) == 1)
return 5.208333333e-03f;
else
return 0.00000000f;
}
inline float dDequantizeNF4(unsigned char val) {
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if ((val & 0b1000) == 8)
if ((val & 0b0100) == 4) // 1
if ((val & 0b0010) == 2) // 11
if ((val & 0b0001) == 1) // 111
return 1.0f; //*1111
else
return 0.7229568362236023f; //*1110
else if ((val & 0b0001) == 1) // 110
return 0.5626170039176941f; //*1101
else
return 0.44070982933044434f; //*1100
else if ((val & 0b0010) == 2) // 10
if ((val & 0b0001) == 1) // 101
return 0.33791524171829224f; //*1011
else
return 0.24611230194568634f; //*1010
else if ((val & 0b0001) == 1) // 100
return 0.16093020141124725f; //*1001
else
return 0.07958029955625534f; //*1000
else if ((val & 0b0100) == 4) // 0
if ((val & 0b0010) == 2) // 01
if ((val & 0b0001) == 1) // 011
return 0.0f; //*0111
else
return -0.09105003625154495f; //*0110
else if ((val & 0b0001) == 1) // 010
return -0.18477343022823334f; //*0101
else
return -0.28444138169288635f; //*0100
else if ((val & 0b0010) == 2) // 00
if ((val & 0b0001) == 1) // 001
return -0.39491748809814453f; //*0011
else
return -0.5250730514526367f; //*0010
else if ((val & 0b0001) == 1) // 000
return -0.6961928009986877f; //*0001
else
return -1.0f; //*0000
}
template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE>
SYCL_EXTERNAL void kDequantizeBlockwise<T, TILE_SIZE, NUM_PER_TH, DATA_TYPE>::operator()(sycl::nd_item<1> item) const {
const int base_idx = item.get_group(0) * TILE_SIZE;
size_t local_idx = item.get_local_id(0) * NUM_PER_TH;
float local_abs_max = -FLT_MAX;
int local_load_idx = 0;
int local_store_idx = 0;
uint8_t qvals[NUM_PER_TH];
T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)];
if (DATA_TYPE > 0) {
local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx);
local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2);
} else {
local_load_idx = sycl::min(TILE_SIZE, n - base_idx);
local_store_idx = local_load_idx;
}
// Avoid expensive division by the blocksize (as blocksize will always be a
// power-of-2)
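// For example, with blocksize = 4096 = 2^12: countl_zero(4096u) == 19, so the
// shift amount is 31 - 19 == 12 and (x >> 12) == x / 4096.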
local_abs_max = absmax[(base_idx + local_idx) >> (31 - std::countl_zero<unsigned int>(blocksize))];
if (local_idx + NUM_PER_TH < local_load_idx) {
reinterpret_cast<sycl::vec<uint8_t, NUM_PER_TH>(&)[NUM_PER_TH]>(qvals)[0] =
reinterpret_cast<sycl::vec<uint8_t, NUM_PER_TH>*>(A)[(base_idx + local_idx) / NUM_PER_TH];
} else {
#pragma unroll NUM_PER_TH
for (int i = 0; i < NUM_PER_TH; i++) {
if (local_idx + i < local_load_idx) {
qvals[i] = A[base_idx + local_idx + i];
} else {
qvals[i] = (uint8_t)0;
}
}
}
switch (DATA_TYPE) {
case General8bit:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++)
vals[j] = code[qvals[j]] * local_abs_max;
break;
case FP4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
vals[j * 2] = dDequantizeFP4(qvals[j] >> 4) * local_abs_max;
vals[j * 2 + 1] = dDequantizeFP4(qvals[j] & 0x0F) * local_abs_max;
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max;
vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max;
}
break;
}
const int local_dst_size = (DATA_TYPE > 0) ? NUM_PER_TH * 2 : NUM_PER_TH;
int local_dst_idx = (DATA_TYPE > 0) ? local_idx * 2 : local_idx;
if (local_dst_idx + local_dst_size < local_store_idx) {
reinterpret_cast<sycl::vec<T, local_dst_size>*>(
out
)[(((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx) / local_dst_size] =
reinterpret_cast<sycl::vec<T, local_dst_size>(&)[local_dst_size]>(vals)[0];
} else {
#pragma unroll NUM_PER_TH
for (int i = 0; i < local_dst_size; i++) {
if (local_dst_idx + i < local_store_idx) {
out[((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx + i] = vals[i];
}
}
}
}
template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD, size_t SUBG_SIZE, int BITS>
SYCL_EXTERNAL void
kgemv_4bit_inference<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE, BITS>::operator()(sycl::nd_item<1> item) const {
size_t idx = item.get_local_id();
const int sg_idx = idx / SUBG_SIZE;
const int sg_lane = idx % SUBG_SIZE;
const int num_values_4bit = SUBG_SIZE;
const int row_B = NUM_PER_THREAD * item.get_group().get_group_id() + sg_idx;
const int offset_B = ldb * row_B;
const int num_values_8bit = num_values_4bit / 2;
float local_C = 0.0f;
unsigned char local_B_4bit[num_values_8bit];
T local_B[num_values_4bit / 4];
T local_A[num_values_4bit / 4];
T local_absmax = T(0.0f);
if (idx < 16) {
quant_map[idx] = T(datatype[idx]);
}
item.barrier(sycl::access::fence_space::local_space);
for (int inner_idx = sg_lane * num_values_4bit; inner_idx < K; inner_idx += SUBG_SIZE * num_values_4bit) {
const int inner_idx_halved = inner_idx / 2;
// Avoid expensive division by the blocksize (as blocksize will always be a
// power-of-2)
const int absidx = ((2 * offset_B) + inner_idx) >> (31 - std::countl_zero((unsigned int)blocksize));
local_absmax = absmax[absidx];
if (row_B < N) {
if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
reinterpret_cast<sycl::vec<int, 4>(&)[num_values_8bit]>(local_B_4bit)[0] =
reinterpret_cast<sycl::vec<int, 4>*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
} else {
#pragma unroll
for (int j = 0; j < (num_values_8bit); j++)
if ((inner_idx_halved) + j < (K / 2))
local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
else
local_B_4bit[j] = 0b01110111;
}
} else {
#pragma unroll
for (int j = 0; j < (num_values_8bit); j++)
local_B_4bit[j] = 0b01110111;
}
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int k = 0; k < num_values_8bit / 4; k++) {
local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
}
if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
if (BITS == 16) {
reinterpret_cast<sycl::vec<int, 4>(&)[num_values_4bit / 4]>(local_A)[0] =
reinterpret_cast<sycl::vec<int, 4>*>(A)[inner_idx / (num_values_4bit / 4) + i];
} else {
reinterpret_cast<sycl::vec<int, 4>(&)[num_values_4bit / 4]>(local_A)[0] =
reinterpret_cast<sycl::vec<int, 4>*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
reinterpret_cast<sycl::vec<int, 4>(&)[num_values_4bit / 4]>(local_A)[1] =
reinterpret_cast<sycl::vec<int, 4>*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
}
} else {
#pragma unroll
for (int k = 0; k < num_values_4bit / 4; k++)
if (inner_idx + (i * num_values_4bit / 4) + k < K)
local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
else
local_A[k] = T(0.0f);
}
// accumulate in float for accuracy;
#pragma unroll
for (int k = 0; k < num_values_4bit / 4; k++) {
local_C += (float)(local_A[k] * local_B[k]);
}
}
}
local_C = sycl::reduce_over_group(item.get_sub_group(), local_C, sycl::plus<>());
if (row_B < N && sg_lane == 0)
out[row_B] = T(local_C);
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template class kDequantizeBlockwise<sycl::half, 512, 4, FP4>;
template class kDequantizeBlockwise<sycl::half, 512, 4, General8bit>;
template class kDequantizeBlockwise<sycl::half, 512, 4, NF4>;
template class kDequantizeBlockwise<float, 512, 4, FP4>;
template class kDequantizeBlockwise<float, 512, 4, General8bit>;
template class kDequantizeBlockwise<float, 512, 4, NF4>;
template class kDequantizeBlockwise<sycl::ext::oneapi::bfloat16, 512, 4, FP4>;
template class kDequantizeBlockwise<sycl::ext::oneapi::bfloat16, 512, 4, General8bit>;
template class kDequantizeBlockwise<sycl::ext::oneapi::bfloat16, 512, 4, NF4>;
template class kgemv_4bit_inference<sycl::half, 128, 4, 32, 16>;
template class kgemv_4bit_inference<sycl::ext::oneapi::bfloat16, 128, 4, 32, 16>;
template class kgemv_4bit_inference<float, 128, 4, 32, 32>;
#include <float.h>
#include <xpu_ops.h>
#ifndef xpu_kernels
#define xpu_kernels
template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE> class kDequantizeBlockwise {
public:
SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const;
kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int n_)
: code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), n(n_) {}
private:
float* code;
uint8_t* A;
float* absmax;
T* out;
const int blocksize;
const int n;
};
template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD, size_t SUBG_SIZE, int BITS> class kgemv_4bit_inference {
public:
SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const;
kgemv_4bit_inference(
int M_, int N_, int K_, T* A_, unsigned char* B_, float* absmax_, const float* datatype_, T* out_, int lda_,
int ldb_, int ldc_, int blocksize_
)
: M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_), out(out_), lda(lda_), ldb(ldb_),
ldc(ldc_), blocksize(blocksize_), quant_map() {}
void sycl_ker_local_memory_creation(sycl::handler& cgh) { quant_map = sycl::local_accessor<T>(16, cgh); }
private:
int M;
int N;
int K;
T* A;
unsigned char* B;
float* absmax;
const float* datatype;
T* out;
int lda;
int ldb;
int ldc;
int blocksize;
sycl::local_accessor<T> quant_map;
};
#endif
#include <common.h>
#include <xpu_kernels.h>
#include <xpu_ops.h>
template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, sycl::queue* stream
) {
auto& queue = *stream;
const int workgroup_size = 128;
const int num_per_th = 4;
const int tile_size = workgroup_size * num_per_th;
if (DATA_TYPE > 0) {
// 4-bit formats (FP4/NF4) pack two values per byte, so a tile covers twice as many
// output elements and the blocksize handed to the kernel is halved.
const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2);
sycl::range<1> local_range{(size_t)workgroup_size};
sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
kDequantizeBlockwise<T, tile_size, num_per_th, DATA_TYPE> kfn(code, A, absmax, out, blocksize / 2, n);
sycl_kernel_submit<decltype(kfn), 1, 32>(
sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
);
} else {
const int workgroup_num = (n + tile_size - 1) / tile_size;
sycl::range<1> local_range{(size_t)workgroup_size};
sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
kDequantizeBlockwise<T, tile_size, num_per_th, DATA_TYPE> kfn(code, A, absmax, out, blocksize, n);
sycl_kernel_submit<decltype(kfn), 1, 32>(
sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
);
}
}
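A minimal host-side sketch of calling this launcher (illustrative, not part of the committed sources): the buffer sizes and blocksize are placeholders, and the USM allocations are assumed to have been filled by a prior quantization step.

```
#include <sycl/sycl.hpp>
#include <xpu_ops.h> // declares dequantizeBlockwise and the DataType_t enum

int main() {
    sycl::queue q{sycl::gpu_selector_v};
    const int n = 4096, blocksize = 64;

    // USM device allocations; assume they already hold quantized data.
    float* code = sycl::malloc_device<float>(256, q);             // 8-bit code table
    unsigned char* A = sycl::malloc_device<unsigned char>(n, q);  // quantized input
    float* absmax = sycl::malloc_device<float>(n / blocksize, q); // per-block scales
    sycl::half* out = sycl::malloc_device<sycl::half>(n, q);

    dequantizeBlockwise<sycl::half, General8bit>(code, A, absmax, out, blocksize, n, &q);
    q.wait();

    sycl::free(code, q); sycl::free(A, q); sycl::free(absmax, q); sycl::free(out, q);
    return 0;
}
```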
template <typename T, int BITS>
void gemv_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
int blocksize, sycl::queue* stream
) {
auto& queue = *stream;
const size_t GROUP_SIZE = 128; // work-group size
const size_t SUBG_SIZE = 32;   // sub-group size
// Each sub-group computes one output row, so a work-group covers GROUP_SIZE / SUBG_SIZE rows.
const size_t NUM_PER_THREAD = GROUP_SIZE / SUBG_SIZE;
size_t workgroup_num = (n + NUM_PER_THREAD - 1) / NUM_PER_THREAD;
kgemv_4bit_inference<T, GROUP_SIZE, NUM_PER_THREAD, SUBG_SIZE, BITS> kfn(
m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize
);
sycl_comp_kernel_submit<decltype(kfn), 1, SUBG_SIZE>(
sycl::nd_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), sycl::range<1>(GROUP_SIZE)), queue, kfn
);
}
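A hedged usage sketch for the GEMV launcher follows. The shapes and leading dimensions (m = 1, lda = k, ldb = (k + 1) / 2, ldc = n) are illustrative assumptions, not values taken from the bitsandbytes Python host code; B is assumed to hold 4-bit packed weights and datatype the 16-entry quantization map.

```
#include <sycl/sycl.hpp>
#include <xpu_ops.h> // declares gemv_4bit_inference

int main() {
    sycl::queue q{sycl::gpu_selector_v};
    const int m = 1, n = 4096, k = 4096, blocksize = 64;

    sycl::half* A = sycl::malloc_device<sycl::half>(k, q);               // activation row (1 x k)
    unsigned char* B = sycl::malloc_device<unsigned char>(n * k / 2, q); // packed 4-bit weights (n x k)
    float* absmax = sycl::malloc_device<float>(n * k / blocksize, q);    // per-block scales
    float* datatype = sycl::malloc_device<float>(16, q);                 // 16-entry quant map
    sycl::half* out = sycl::malloc_device<sycl::half>(n, q);

    gemv_4bit_inference<sycl::half, 16>(m, n, k, A, B, absmax, datatype, out,
                                        k, (k + 1) / 2, n, blocksize, &q);
    q.wait();
    return 0;
}
```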
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template void dequantizeBlockwise<float, General8bit>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
);
template void dequantizeBlockwise<float, FP4>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
);
template void dequantizeBlockwise<float, NF4>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
);
template void dequantizeBlockwise<sycl::half, General8bit>(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
);
template void dequantizeBlockwise<sycl::half, FP4>(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
);
template void dequantizeBlockwise<sycl::half, NF4>(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
);
template void dequantizeBlockwise<sycl::ext::oneapi::bfloat16, General8bit>(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
sycl::queue* stream
);
template void dequantizeBlockwise<sycl::ext::oneapi::bfloat16, FP4>(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
sycl::queue* stream
);
template void dequantizeBlockwise<sycl::ext::oneapi::bfloat16, NF4>(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
sycl::queue* stream
);
template void gemv_4bit_inference<sycl::half, 16>(
int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda,
int ldb, int ldc, int blocksize, sycl::queue* stream
);
template void gemv_4bit_inference<sycl::ext::oneapi::bfloat16, 16>(
int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype,
sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
);
template void gemv_4bit_inference<float, 32>(
int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
int ldc, int blocksize, sycl::queue* stream
);
#ifndef xpu_ops_H
#define xpu_ops_H
#include <assert.h>
#include <cstdint>
#include <iostream>
#include <stdio.h>
#include <functional>
#include <vector>
#include <sycl/sycl.hpp>
// Submit a kernel functor with a fixed required sub-group size.
template <typename ker_t, int dim, int subgroup_size>
static inline void sycl_kernel_submit(sycl::nd_range<dim> range, sycl::queue q, ker_t ker) {
auto cgf = [&](::sycl::handler& cgh)
[[sycl::reqd_sub_group_size(subgroup_size)]] { cgh.parallel_for<ker_t>(range, ker); };
q.submit(cgf);
}
// Same as sycl_kernel_submit, but lets the functor allocate work-group local memory
// (e.g. the 16-entry quant_map accessor of kgemv_4bit_inference) before the parallel_for is enqueued.
template <typename ker_t, int dim, int subgroup_size>
static inline void sycl_comp_kernel_submit(sycl::nd_range<dim> range, sycl::queue q, ker_t ker) {
auto cgf = [&](::sycl::handler& cgh) [[sycl::reqd_sub_group_size(subgroup_size)]] {
ker.sycl_ker_local_memory_creation(cgh);
cgh.parallel_for<ker_t>(range, ker);
};
q.submit(cgf);
}
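To make the contract of these helpers concrete, here is a small self-contained sketch (not part of the sources) of a functor that satisfies what sycl_kernel_submit expects: a const call operator taking sycl::nd_item<1> and plain data members captured by value. The FillOnes functor and buffer size are made up for illustration.

```
#include <sycl/sycl.hpp>
#include <xpu_ops.h> // assumption: provides sycl_kernel_submit as defined above

struct FillOnes {
    float* out;
    void operator()(sycl::nd_item<1> item) const { out[item.get_global_linear_id()] = 1.0f; }
};

int main() {
    sycl::queue q{sycl::gpu_selector_v};
    float* buf = sycl::malloc_device<float>(256, q);

    FillOnes kfn{buf};
    // 256 work-items in work-groups of 128, with a required sub-group size of 32.
    sycl_kernel_submit<FillOnes, 1, 32>(
        sycl::nd_range<1>(sycl::range<1>(256), sycl::range<1>(128)), q, kfn
    );
    q.wait();
    sycl::free(buf, q);
    return 0;
}
```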
typedef enum DataType_t {
General8bit = 0,
FP4 = 1,
NF4 = 2,
} DataType_t;
template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, sycl::queue* stream
);
template <typename T, int BITS>
void gemv_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
int blocksize, sycl::queue* stream
);
#endif
@@ -138,8 +138,8 @@ We provide an early preview of support for AMD and Intel hardware as part of a d
 | **Backend** | **Supported Versions** | **Python versions** | **Architecture Support** | **Status** |
 |-------------|------------------------|---------------------------|-------------------------|------------|
 | **AMD ROCm** | 6.1+ | 3.10+ | minimum CDNA - `gfx90a`, RDNA - `gfx1100` | Alpha |
-| **Intel CPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha |
-| **Intel GPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental |
+| **Intel CPU** | v2.4.0+ | 3.10+ | Intel CPU | Alpha |
+| **Intel GPU** | v2.7.0+ | 3.10+ | Intel GPU | Experimental |
 | **Ascend NPU** | 2.1.0+ (`torch_npu`) | 3.10+ | Ascend NPU | Experimental |

 For each supported backend, follow the respective instructions below:

@@ -179,7 +179,6 @@ pip install torch --index-url https://download.pytorch.org/whl/rocm6.3/
 <hfoption id="Intel XPU">

 * A compatible PyTorch version with Intel XPU support is required. It is recommended to use the latest stable release. See [Getting Started on Intel GPU](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) for guidance.
-* The [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/) is recommended for performance improvements.

 </hfoption>
 </hfoptions>
@@ -235,27 +234,18 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise
 </hfoption>
 <hfoption id="Intel CPU + GPU">

-#### Intel CPU + XPU
+#### Intel CPU + GPU(XPU)

-If you are using Intel CPU on Linux or Intel XPU on Linux/Windows, please follow the [instruction](https://pytorch-extension.intel.com/) or the following command to install intel_extension_for_pytorch so you can get better performance.
-
-CPU: `pip install intel_extension_for_pytorch`
-XPU: `pip install intel_extension_for_pytorch --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/`
-
-Install bitsandbytes:
-
-CPU: Need to build CPU C++ codes
+Building for CPU compiles the C++ backend, while building for XPU compiles the SYCL kernels.
+Run `export bnb_device=xpu` if you are using XPU, or `export bnb_device=cpu` if you are using CPU.

 ```
 git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/
-cmake -DCOMPUTE_BACKEND=cpu -S .
+cmake -DCOMPUTE_BACKEND=$bnb_device -S .
 make
-pip install .
-```
-
-XPU:
-```
-pip install git+https://github.com/bitsandbytes-foundation/bitsandbytes.git
+pip install -e .
 ```

 </hfoption>
 <hfoption id="Ascend NPU">
...
@@ -143,11 +143,11 @@ class Test8BitBlockwiseQuantizeFunctional:
         abserr = sum(diffs) / len(diffs)
         relerr = sum(reldiffs) / len(reldiffs)
         if signed:
-            threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035
+            threshold_abserr = 0.0035
             assert abserr < 0.0036
             assert relerr < 0.015
         else:
-            assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023
+            assert abserr < 0.0023
             assert relerr < 0.012
         assert A2.dtype == dtype
@@ -178,8 +178,8 @@ class Test8BitBlockwiseQuantizeFunctional:
     @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
     @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"])
     def test_few_bit_quant(self, device, bits, method):
-        if bits != 8 and (device == "cpu" or (device == "xpu" and F.ipex_xpu)):
-            pytest.skip("CPU/XPU implementation only supports 8 bits")
+        if bits != 8 and device == "cpu":
+            pytest.skip("CPU implementation only supports 8 bits")
         abserrs = []
         relerrs = []
@@ -1274,8 +1274,8 @@ class TestQuantize4BitFunctional:
         max_errs3 = []

         # Large number of iterations is excessive and slow on CPU.
-        # Keep for CUDA for now.
-        iters = 100 if device == "cuda" else 10
+        # Keep for CUDA/XPU for now.
+        iters = 10 if device == "cpu" else 100
         for i in range(iters):
             if kind == "fc1":
@@ -1377,13 +1377,13 @@ class TestQuantize4BitFunctional:
                 assert err1 < 6e-5
                 assert relerr1 < 2e-4
             assert absratio < 1.005 and absratio > 0.995
-            assert relratio < 1.005 and relratio > 0.995
-            assert maxratio < 1.005 and maxratio > 0.995
+            assert relratio < 1.005 and relratio > 0.992
+            assert maxratio < 1.005 and maxratio > 0.992
         elif dtype == torch.float32:
             if dim <= 512:
                 assert err1 < 5e-8
                 assert relerr1 < 1e-6
-                assert maxerr1 < 1e-7
+                assert maxerr1 < 1.05e-7
             else:
                 assert err1 < 5e-8
                 assert relerr1 < 8e-6
@@ -1393,16 +1393,17 @@ class TestQuantize4BitFunctional:
             assert maxratio < 1.005 and maxratio > 0.995
         elif dtype == torch.bfloat16:
             if dim <= 512:
+                relerr_thres = 0.013 if hasattr(torch, "xpu") and torch.xpu.is_available() else 0.007
                 assert err1 < 6e-4
-                assert relerr1 < 0.007
+                assert relerr1 < relerr_thres
                 assert maxerr1 < 0.015
             else:
                 assert err1 < 2e-4
                 assert relerr1 < 0.002
                 assert maxerr1 < 0.0012
             assert absratio < 1.005 and absratio > 0.995
-            assert relratio < 1.04 and relratio > 0.96
-            assert maxratio < 1.02 and maxratio > 0.98
+            assert relratio < 1.05 and relratio > 0.96
+            assert maxratio < 1.05 and maxratio > 0.97

     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
...