Unverified commit 945f7c1d, authored by Matthew Douglas and committed by GitHub

Fix CI regression (#1666)

* Tests: xfail opcheck for 4bit quantization with floating storage dtypes

* Tests: skip test_gemv_eye_4bit on CPU with bf16 when not supported by torch
parent a2a74ede
@@ -34,7 +34,7 @@ supported_torch_devices = {
 if torch.cuda.is_available():
     from .backends.cuda import ops as cuda_ops
-if torch.xpu.is_available():
+if hasattr(torch, "xpu") and torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops
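Why the guard: torch.xpu only exists as a submodule in PyTorch builds that ship XPU support, so calling torch.xpu.is_available() directly raises AttributeError at import time on older builds instead of returning False. A minimal sketch of the pattern (the helper name is illustrative, not part of the codebase):

    import torch

    def xpu_available() -> bool:
        # Probe for the submodule first; older torch builds have no
        # torch.xpu attribute at all, so is_available() cannot be called.
        return hasattr(torch, "xpu") and torch.xpu.is_available()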
@@ -30,7 +30,9 @@ _NF4_QUANT_TABLE = torch.tensor(
         1.0,
     ],
     dtype=torch.float32,
-    device="xpu" if torch.xpu.is_available() else "cpu",  # Only cpu/xpu use this table for now.
+    device="xpu"
+    if hasattr(torch, "xpu") and torch.xpu.is_available()
+    else "cpu",  # Only cpu/xpu use this table for now.
 )
 _FP4_QUANT_TABLE = torch.tensor(
     [
@@ -52,6 +54,8 @@ _FP4_QUANT_TABLE = torch.tensor(
         -0.2500,
     ],
     dtype=torch.float32,
-    device="xpu" if torch.xpu.is_available() else "cpu",  # Only cpu/xpu use this table for now.
+    device="xpu"
+    if hasattr(torch, "xpu") and torch.xpu.is_available()
+    else "cpu",  # Only cpu/xpu use this table for now.
 )
 CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}
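For context, these code tables drive lookup-based 4-bit quantization: each scaled input element is mapped to the index of its nearest table entry. A simplified sketch of that idea, assuming a 16-entry table like _NF4_QUANT_TABLE; the real kernels additionally quantize per block and pack two 4-bit indices per byte:

    import torch

    def quantize_with_table(x: torch.Tensor, code: torch.Tensor) -> torch.Tensor:
        # Scale the input into the table's [-1, 1] range, then pick the
        # nearest code index for every element.
        absmax = x.abs().max().clamp(min=1e-8)
        scaled = (x / absmax).unsqueeze(-1)         # (..., 1)
        idx = (scaled - code).abs().argmin(dim=-1)  # nearest of 16 entries
        return idx.to(torch.uint8)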
@@ -1330,6 +1330,9 @@ class TestQuantize4BitFunctional:
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
     @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
     def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
+        if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
+            pytest.skip("eye does not support bfloat16 on CPU in torch < 2.3")
+
         dims = 10
         torch.random.manual_seed(np.random.randint(0, 412424242))
         dims = get_test_dims(0, 8192, n=dims)
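The bare tuple comparison works because torch.__version__ is a TorchVersion object that supports ordering against version tuples. An equivalent, more explicit guard (the helper name is illustrative, and packaging is an extra dependency here):

    import pytest
    import torch
    from packaging import version

    def skip_if_cpu_bf16_eye_unsupported(device: str, dtype: torch.dtype) -> None:
        # torch.eye gained bfloat16 support on CPU in torch 2.3.
        if device == "cpu" and dtype == torch.bfloat16 and version.parse(torch.__version__).release < (2, 3):
            pytest.skip("eye does not support bfloat16 on CPU in torch < 2.3")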
@@ -167,9 +167,8 @@ class Test4bitBlockwiseQuantOps:
         assert absmax.device == A.device
         assert absmax.dtype == torch.float32

-        # TODO: Enable it
-        if device in ("cpu", "xpu") and storage_dtype == torch.bfloat16:
-            pytest.skip("CPU bf16 storage_dtype will fail on torch op check")
+        if storage_dtype != torch.uint8:
+            pytest.xfail("opcheck fails for storage_dtype != torch.uint8")

         opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
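Note the semantic shift from pytest.skip to the imperative pytest.xfail: both end the test immediately, but skip reports the case as not applicable, while xfail reports it as a known, expected failure, which is the honest bucket for an opcheck gap that should eventually be fixed. A minimal sketch of the pattern (the helper name is illustrative):

    import pytest
    import torch

    def bail_on_unsupported_storage(storage_dtype: torch.dtype) -> None:
        # Imperative xfail stops the test here and files it under
        # "expected failures" rather than "skipped" in the CI summary.
        if storage_dtype != torch.uint8:
            pytest.xfail("opcheck fails for storage_dtype != torch.uint8")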