Unverified commit 1813b058 authored by Liu Xiaoli, committed by GitHub

Add SYCL Kernels for XPU backend (#1679)



* Add SYCL Kernels for XPU backend

* fix transpose
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix log and format
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert cpu changes
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* clean ipex_xpu
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* clean ipex import
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix ipex cpu import
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix typo
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comments
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* refine gemv_4bit kernel

* enable FP4 for dequant_4bit and gemv_4bit

* refine FP4 dequantization performance

* remove check for better performance
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix doc
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* clean code

* fix tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* rm comments
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix memory issue

* fix ut failure

* adjust threshold
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix xpu check
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* change test_functional check
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix test_module
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix device check
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Enable Windows build and refine code

* fix xpu log
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* remove ipex entirely
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cpu int8 CB
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix lint
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix logs (#12)

* fix logs
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix sycl lint error and tests (#13)

* fix sycl nd
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip typo check for xpu kernel codes (#14)

* skip test for xpu ops
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix lint
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip typo for xpu
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* register triton kernel for quantization (#15)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix version comparison issue (#18)

# Description

The version comparison expression applies the `.release` property to only one side of the comparison. This leads to comparing a tuple against a `Version` object, which raises a `TypeError`.

# Error message
```
The 8-bit optimizer is not available on your device, only available on CUDA for now.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Traceback (most recent call last):
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/unsloth_validation/run.py", line 1, in <module>
    import unsloth
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/__init__.py", line 235, in <module>
    from .models import *
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/__init__.py", line 15, in <module>
    from .llama     import FastLlamaModel
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/llama.py", line 23, in <module>
    from ._utils import *
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth/models/_utils.py", line 89, in <module>
    from unsloth_zoo.patching_utils import (
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/unsloth_zoo/patching_utils.py", line 629, in <module>
    import transformers.integrations.bitsandbytes
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/v/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py", line 20, in <module>
    import bitsandbytes as bnb
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/bitsandbytes/bitsandbytes/__init__.py", line 39, in <module>
    from .backends.xpu import ops as xpu_ops
  File "/home/erxin/jenkins/workspace/Unsloth_Benchmark/bitsandbytes/bitsandbytes/backends/xpu/ops.py", line 17, in <module>
    if version.parse(torch.__version__).release >= version.parse("2.9"):
TypeError: '>=' not supported between instances of 'tuple' and 'Version'
```
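A minimal sketch of the corrected guard, assuming the fix applies `.release` to both sides (the exact patch may differ). `version.parse("2.9").release` yields the plain tuple `(2, 9)`, so both operands of `>=` become tuples:

```python
import torch
from packaging import version

# Compare release tuples on both sides so `>=` sees two tuples,
# not a tuple on the left and a Version object on the right.
if version.parse(torch.__version__).release >= version.parse("2.9").release:
    ...  # register the torch>=2.9 XPU op path
```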

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Er-Xin (Edwin) Shang <shangerxin@hotmail.com>
parent 275671be
```diff
@@ -272,14 +272,11 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):

     # Test with gradients. Currently only works with threshold=0.
     # Has a strange regression on Linux aarch64 CPU in torch==2.6.0.
-    # There is also an issue with torch==2.7.0 on x86-64 with IPEX.
     is_broken_platform = (
         device == "cpu"
         and platform.system() == "Linux"
-        and (
-            (platform.machine() == "aarch64" and (2, 6) <= torch.__version__ < (2, 7))
-            or (platform.machine() == "x86_64" and bnb.functional.ipex_cpu)
-        )
+        and platform.machine() == "aarch64"
+        and (2, 6) <= torch.__version__ < (2, 7)
     )

     if threshold == 0 and not is_broken_platform:
```
```diff
@@ -143,9 +143,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
     b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
     o1 = mlp(b1)
     assert o1.dtype == torch.float16
-    if threshold > 0:
+    if threshold > 0 and device not in ("cpu", "xpu"):
         assert mlp.fc1.state.idx is not None
-    if threshold > 0:
         assert mlp.fc2.state.idx is not None

     mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
@@ -156,9 +155,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
     b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
     o1 = mlp(b1)
     assert o1.dtype == torch.float16
-    if threshold > 0:
+    if threshold > 0 and device not in ("cpu", "xpu"):
         assert mlp.fc1.state.idx is not None
-    if threshold > 0:
         assert mlp.fc2.state.idx is not None

     mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)
@@ -167,9 +165,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
     b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
     o1 = mlp(b1)
     assert o1.dtype == torch.float16
-    if threshold > 0:
+    if threshold > 0 and device not in ("cpu", "xpu"):
         assert mlp.fc1.state.idx is not None
-    if threshold > 0:
         assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
@@ -189,9 +186,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
     b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
     o1 = mlp(b1)
     assert o1.dtype == torch.float16
-    if threshold > 0:
+    if threshold > 0 and device not in ("cpu", "xpu"):
         assert mlp.fc1.state.idx is not None
-    if threshold > 0:
         assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
@@ -211,9 +207,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
     b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
     o1 = mlp(b1)
     assert o1.dtype == torch.float16
-    if threshold > 0:
+    if threshold > 0 and device not in ("cpu", "xpu"):
         assert mlp.fc1.state.idx is not None
-    if threshold > 0:
         assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
```
```diff
@@ -5,7 +5,6 @@ import torch

 import bitsandbytes
 from bitsandbytes.cextension import HIP_ENVIRONMENT
-from bitsandbytes.functional import ipex_xpu
 from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu

 # torch.library.opcheck is only available in torch 2.4 and later.
@@ -145,10 +144,6 @@ class TestInt8BlockwiseQuantOps:
         assert out.dtype == dtype
         assert out.device == A.device

-        # TODO: Enable it
-        if device == "xpu" and ipex_xpu:
-            pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check")
-
         opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
```