Unverified Commit aaa71d7e authored by jiqing-feng, committed by GitHub

Enable CPU/XPU native and ipex path (#1628)



* enable ipex
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cpu 8bit quantization
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix int8 and nf4 cpu inference
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add cpu fp4 and rem
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix dequantize nf4 xpu
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix ipex op
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix dequantize nf4 name
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix dequantize nf4 ipex
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix matmul8bitfp
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable cpu tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix quantize blockwise output shape
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix quant_storage bf16 and gemv cpu
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix cpu tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix xpu tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix lib
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip xpu dequantize blockwise op check
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix matmul8bit
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* skip unused function tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix matmul8bit fp
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* check ipex before MatMul8bitFp
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update ipex install guide
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update install guide
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix error log
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix error log
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update comment
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* move torch op to default
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert ipex check
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix code table device
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix code table device
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix xpu ops
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
parent 1e54f912
......@@ -34,6 +34,9 @@ supported_torch_devices = {
if torch.cuda.is_available():
from .backends.cuda import ops as cuda_ops
if torch.xpu.is_available():
from .backends.xpu import ops as xpu_ops
def _import_backends():
"""
......
......@@ -4,6 +4,8 @@ from typing import Optional
import torch
from .cextension import ipex_cpu, ipex_xpu
_IS_TORCH_GTE_24 = False
if hasattr(torch.library, "register_fake"):
......@@ -327,3 +329,22 @@ def _(
)
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
if ipex_cpu or ipex_xpu:
# Define the dequantize_nf4_ipex op and register its fake (meta) implementation
torch.library.define(
"bitsandbytes::dequantize_nf4_ipex",
"(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
)
@register_fake("bitsandbytes::dequantize_nf4_ipex")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
return torch.empty(shape, dtype=dtype, device=A.device)
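For readers unfamiliar with the `torch.library.define` / `register_fake` pattern used for `bitsandbytes::dequantize_nf4_ipex` above, here is a minimal self-contained sketch with a hypothetical op name (`mylib::scale`); it assumes torch >= 2.4, where `torch.library.register_fake` is available.

```python
# Minimal sketch of defining a custom op schema plus a fake (meta) kernel and a
# real CPU kernel, mirroring the pattern used above. "mylib::scale" is hypothetical.
import torch

torch.library.define("mylib::scale", "(Tensor x, float s) -> Tensor")

@torch.library.register_fake("mylib::scale")
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    # The fake kernel only describes the output metadata (shape/dtype/device).
    return torch.empty_like(x)

@torch.library.impl("mylib::scale", "cpu")
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

print(torch.ops.mylib.scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])
```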
......@@ -8,6 +8,7 @@ import torch
from typing_extensions import deprecated
import bitsandbytes.functional as F
from bitsandbytes.functional import ipex_cpu, ipex_xpu
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
......@@ -298,6 +299,63 @@ class MatMul8bitLt(torch.autograd.Function):
return grad_A, grad_B, None, grad_bias, None
class MatMul8bitFp(torch.autograd.Function):
# On Intel CPU and XPU, MatMul8bitFp is much faster (~3x) than MatMul8bitLt for fine-tuning,
# because MatMul8bitLt does extra work when computing gradients and there is no fast
# 8-bit quant/dequant kernel on CPU/XPU, which makes that path very slow.
# Dequantizing the weight and running a plain matmul keeps fine-tuning performance reasonable.
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
if state.has_fp16_weights or state.CB is None:
has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
B = B.contiguous()
if (state.is_training and not has_grad) or state.CB is None or state.SCB is None:
state.reset_grads()
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
B = state.CB
CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
output = torch.nn.functional.linear(A, CB, bias)
# to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu]
state.idx = False
ctx.state = state
ctx.dtype_A = A.dtype
ctx.grad_shape = A.shape
ctx.A = A
ctx.dtype_bias = None if bias is None else bias.dtype
return output
@staticmethod
def backward(ctx, grad_output):
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
A = ctx.A
state = ctx.state
grad_A = grad_B = grad_bias = None
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
# Flatten batched (3D) grad_output to 2D for the matmuls below
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
if req_gradB:
grad_B = torch.matmul(A.t(), grad_output).t()
if req_gradA:
if state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
else:
raise Exception("State must contain CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None
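The comment at the top of `MatMul8bitFp` is the key idea: dequantize the int8 weight once and fall back to a plain matmul. A rough standalone sketch of that forward path (no autograd plumbing; the helper name and the hand-rolled quantization are illustrative, but the `SCB / 127` scaling convention matches the code above):

```python
# Illustrative only: reconstruct an approximate floating-point weight from the int8
# matrix CB and its row-wise scales SCB (stored as absmax, hence the /127), then run
# a plain matmul. This is the core of the MatMul8bitFp forward path.
from typing import Optional

import torch

def matmul_8bit_fp_sketch(A: torch.Tensor, CB: torch.Tensor, SCB: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # CB: [out_features, in_features] int8, SCB: [out_features] float32 row scales
    W = CB.to(A.dtype) * (SCB.to(A.dtype).unsqueeze(1) / 127.0)
    return torch.nn.functional.linear(A, W, bias)

# Toy usage with a weight quantized by hand:
W_fp = torch.randn(8, 16)
SCB = W_fp.abs().amax(dim=1)                            # row-wise absmax
CB = torch.round(W_fp / SCB.unsqueeze(1) * 127).to(torch.int8)
A = torch.randn(4, 16)
print(matmul_8bit_fp_sketch(A, CB, SCB).shape)          # torch.Size([4, 8])
```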
class MatMul4Bit(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
......@@ -366,6 +424,10 @@ def matmul(
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
# MatMul8bitLt is slower on CPU/XPU because there is no fast 8-bit quant/dequant kernel there
if state.is_training:
if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu):
return MatMul8bitFp.apply(A, B, out, bias, state)
return MatMul8bitLt.apply(A, B, out, bias, state)
......@@ -378,6 +440,17 @@ def matmul_4bit(
):
assert quant_state is not None
if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
# IPEX CPU packs the weight into a 4D layout, so no transpose is needed
B = B.t() if B.dim() == 2 else B
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(
......
......@@ -7,6 +7,7 @@ from bitsandbytes.functional import get_ptr
from ..._ops import register_kernel
from ...cextension import lib
from ..utils import ipex_cpu
# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
......@@ -26,9 +27,11 @@ if torch.__version__ >= (2, 6):
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
n = A.numel()
# Only FP32 has a C++ kernel
if A.dtype == torch.float32:
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
......@@ -42,6 +45,24 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
ct.c_longlong(blocksize),
ct.c_longlong(n),
)
else:
rem = n % blocksize
has_rem = rem > 0
blocks = n // blocksize + has_rem
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
A_reshaped = A.reshape(n)
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
return out, absmax
......@@ -50,8 +71,9 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")
# Only FP32 has a C++ kernel
if dtype == torch.float32:
out = torch.empty_like(A, dtype=dtype)
lib.cdequantize_blockwise_cpu_fp32(
......@@ -62,132 +84,37 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
else:
out = code[A.reshape(-1).int()]
blocks = out.shape[-1] // blocksize
res = out.shape[-1] % blocksize
if res != 0:
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
out = out[: blocks * blocksize + res]
out = out.reshape(A.shape)
return out
_NF4_QUANT_TABLE = torch.tensor(
[
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
],
dtype=torch.float32,
device="cpu",
)
@register_kernel("bitsandbytes::quantize_4bit", "cpu")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)
n = A.numel()
# TODO: Support when weight matrix is not divisible by blocksize
torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
# Divide into blocks and normalize
blocks = A.reshape(-1, blocksize)
absmax = blocks.abs().max(dim=1).values.float()
scaled = blocks / absmax.unsqueeze(-1)
# Quantize with the lookup table
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
if ipex_cpu:
from bitsandbytes.utils import _reverse_4bit_compress_format
# Pack two quantized values per byte
packed = quantized[::2] << 4 | quantized[1::2]
if quant_storage != torch.uint8:
packed = packed.squeeze().view(quant_storage).unsqueeze(1)
return packed, absmax.float()
@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
def _(
@register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
torch._check(
A.dtype == torch.uint8,
lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
)
A = A.view(-1, 1)
# Grab upper and lower nibbles. Using int64 for indexing in the LUT.
upper = (A >> 4).to(torch.int64)
lower = (A & 0x0F).to(torch.int64)
# Expand to blocks
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
# Dequantize
blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
# Reshape to original shape
blocks = blocks.reshape(-1, *shape[1:])
return blocks.to(dtype)
@register_kernel("bitsandbytes::gemv_4bit", "cpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
# TODO: We need to determine whether `code` is NF4, FP4, or other.
# Right now we assume NF4, as this is the only one supported on CPU.
B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
B,
) -> torch.Tensor:
ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
return torch.ops.bitsandbytes.dequantize_4bit.default(
A,
absmax,
blocksize,
"nf4",
shape=shapeB,
dtype=A.dtype,
)
# User called gemv with B.t(), so we need to transpose it back.
# if B.shape[0] == 1:
# B_dq = B_dq.t()
return torch.nn.functional.linear(
A,
B_dq,
bias=None,
shape,
dtype,
)
from collections.abc import Sequence
from math import prod
from typing import Optional
import torch
from ..._ops import register_kernel
from ..utils import CODE
@register_kernel("bitsandbytes::int8_mm_dequant", "default")
......@@ -142,3 +144,160 @@ def _(A: torch.Tensor, threshold=0.0):
A[outliers] = outlier_restore
return out_row, row_stats, outlier_cols
@register_kernel("bitsandbytes::quantize_blockwise", "default")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
n = A.numel()
rem = n % blocksize
has_rem = rem > 0
blocks = n // blocksize + has_rem
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
A_reshaped = A.reshape(n)
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
return out, absmax
@register_kernel("bitsandbytes::dequantize_blockwise", "default")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
out = code[A.reshape(-1).int()]
blocks = out.shape[-1] // blocksize
res = out.shape[-1] % blocksize
if res != 0:
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
out = out[: blocks * blocksize + res]
out = out.reshape(A.shape)
return out
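To make the block/remainder handling in the default `quantize_blockwise`/`dequantize_blockwise` kernels concrete, here is a small self-contained round trip that mirrors the same logic with a toy code table (a plain linspace stand-in, not the real dynamic-quantization map; function names are illustrative):

```python
# Reference round trip: per-block absmax scaling, nearest-code lookup, and the
# reverse. The last block may be shorter than blocksize, which exercises the
# remainder path handled above.
import torch

def quantize_blockwise_ref(A: torch.Tensor, code: torch.Tensor, blocksize: int):
    n = A.numel()
    rem = n % blocksize
    blocks = n // blocksize + (1 if rem else 0)
    absmax = torch.zeros(blocks, dtype=torch.float32)
    A_flat = A.reshape(n)
    scaled = []
    for b in range(blocks):
        chunk = A_flat[b * blocksize : (b + 1) * blocksize]
        absmax[b] = chunk.abs().max()
        scaled.append(torch.clamp(chunk / absmax[b], -1, 1))
    scaled = torch.cat(scaled)
    out = torch.argmin((scaled.unsqueeze(-1) - code).abs(), dim=-1).to(torch.uint8)
    return out.reshape(A.shape), absmax

def dequantize_blockwise_ref(Q: torch.Tensor, absmax: torch.Tensor,
                             code: torch.Tensor, blocksize: int) -> torch.Tensor:
    vals = code[Q.reshape(-1).int()]
    chunks = [vals[b * blocksize : (b + 1) * blocksize] * absmax[b]
              for b in range(absmax.numel())]
    return torch.cat(chunks).reshape(Q.shape)

code = torch.linspace(-1, 1, 256)   # toy stand-in for the 8-bit code table
A = torch.randn(1000)               # 1000 % 256 != 0, so the remainder path is used
q, absmax = quantize_blockwise_ref(A, code, blocksize=256)
A_hat = dequantize_blockwise_ref(q, absmax, code, blocksize=256)
print((A - A_hat).abs().mean())     # small reconstruction error
```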
@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)
n = A.numel()
full_blocks = n // blocksize
rem = n % blocksize
blocks = full_blocks + 1 if rem else full_blocks
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
A_flattened = A.reshape(n)
# Scale full blocks of the tensor to [-1, 1]
A_full_blocks = A_flattened[: n - rem].reshape(n // blocksize, blocksize)
absmax[:full_blocks] = torch.abs(A_full_blocks).max(dim=-1)[0]
scaled = torch.clamp(A_full_blocks * (1 / absmax[:full_blocks].view(-1, 1)), -1, 1).reshape(-1)
# Scale any partial block
if rem:
A_rem = A_flattened[-rem:]
absmax[-1] = torch.abs(A_rem).max()
scaled_rem = torch.clamp(A_rem * (1 / absmax[-1]), -1, 1)
scaled = torch.cat([scaled, scaled_rem], dim=0)
# Quantize with the lookup table
code = CODE[quant_type].to(scaled.device).to(scaled.dtype)
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - code), dim=-1, keepdim=True).to(torch.uint8)
# Pack two quantized values per byte
packed = quantized[::2] << 4 | quantized[1::2]
if quant_storage != torch.uint8:
packed = packed.squeeze().view(quant_storage).unsqueeze(1)
return packed, absmax.float()
@register_kernel("bitsandbytes::dequantize_4bit", "default")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
# Allow non-uint8 storage dtypes by reinterpreting the buffer as uint8
if A.dtype != torch.uint8:
A = A.view(torch.uint8)
A = A.reshape(-1)
# Unpack the two 4-bit indices stored in each byte
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[1::2] = A & 0xF
out_dq[::2] = A >> 4
# The code table is fp32; cast it to the output dtype to avoid a dtype mismatch
code = CODE[quant_type].to(dtype).to(A.device)
out_dq = code[out_dq]
# Apply scales
if out_dq.numel() != n:
assert out_dq.numel() == n + 1
out_dq = torch.narrow(out_dq, 0, 0, n)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
has_rem = rem > 0
out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1)
if has_rem:
out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1)
out[n - rem :] = out_dq[n - rem :] * absmax[-1]
else:
out = out_dq.view(-1, blocksize) * absmax.view(-1, 1)
out = out.reshape(-1, *shape[1:]).to(dtype)
return out
@register_kernel("bitsandbytes::gemv_4bit", "default")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
# Infer quant_type from the code table (FP4's second entry is positive, NF4's is negative), then reuse dequantize_4bit
quant_type = "fp4" if code[1] > 0 else "nf4"
B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(B, absmax, blocksize, quant_type, shapeB, A.dtype)
return torch.nn.functional.linear(
A,
B_dq,
bias=None,
)
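The default 4-bit kernels above follow the same recipe: scale each block by its absmax, snap every value to a 16-entry code table, and pack two 4-bit indices into each byte. A compact round-trip sketch of that recipe, using a uniform stand-in for the NF4/FP4 tables and assuming the element count is divisible by the blocksize:

```python
# Sketch of blockwise 4-bit quantization with nibble packing. CODE16 is a uniform
# stand-in for the real NF4/FP4 tables; helper names are illustrative.
import torch

CODE16 = torch.linspace(-1, 1, 16)

def quantize_4bit_ref(A: torch.Tensor, blocksize: int):
    blocks = A.reshape(-1, blocksize)
    absmax = blocks.abs().amax(dim=1)
    scaled = blocks / absmax.unsqueeze(1)
    idx = torch.argmin((scaled.reshape(-1, 1) - CODE16).abs(), dim=-1).to(torch.uint8)
    packed = (idx[::2] << 4) | idx[1::2]            # two 4-bit indices per byte
    return packed, absmax

def dequantize_4bit_ref(packed: torch.Tensor, absmax: torch.Tensor,
                        blocksize: int, shape) -> torch.Tensor:
    upper = (packed >> 4).to(torch.int64)
    lower = (packed & 0x0F).to(torch.int64)
    idx = torch.stack((upper, lower), dim=1).reshape(-1)    # restore the original order
    out = CODE16[idx].reshape(-1, blocksize) * absmax.unsqueeze(1)
    return out.reshape(shape)

A = torch.randn(128, 64)
packed, absmax = quantize_4bit_ref(A, blocksize=64)
A_hat = dequantize_4bit_ref(packed, absmax, blocksize=64, shape=A.shape)
print(packed.shape, packed.dtype, (A - A_hat).abs().mean())
```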
import torch
try:
# to support Intel CPU/XPU (IPEX) backend
import intel_extension_for_pytorch as ipex
ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
except BaseException:
ipex_cpu = None
ipex_xpu = None
_NF4_QUANT_TABLE = torch.tensor(
[
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
],
dtype=torch.float32,
device="xpu" if torch.xpu.is_available() else "cpu", # Only cpu/xpu use this table for now.
)
_FP4_QUANT_TABLE = torch.tensor(
[
0.0000,
0.0052,
0.6667,
1.0000,
0.3333,
0.5000,
0.1667,
0.2500,
0.0000,
-0.0052,
-0.6667,
-1.0000,
-0.3333,
-0.5000,
-0.1667,
-0.2500,
],
dtype=torch.float32,
device="xpu" if torch.xpu.is_available() else "cpu", # Only cpu/xpu use this table for now.
)
CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}
from collections.abc import Sequence
import torch
from ..._ops import register_kernel
from ..utils import ipex_xpu
if torch.__version__ >= (2, 7):
@register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
def _(A: torch.Tensor, B: torch.Tensor):
return torch._int_mm(
A.reshape(-1, A.shape[-1]),
B.t(),
).reshape(*A.shape[:-1], B.shape[0])
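The `int8_linear_matmul` kernel above is just `torch._int_mm` on flattened inputs. A minimal illustration of the same reshape convention (availability and shape constraints of the private `_int_mm` op depend on the torch version and device; see the AVX512_VNNI note earlier in this diff for CPU):

```python
# s8 @ s8 -> s32 matmul with the same reshape convention as the kernel above.
import torch

A = torch.randint(-128, 127, (4, 32, 64), dtype=torch.int8)   # [..., in_features]
B = torch.randint(-128, 127, (48, 64), dtype=torch.int8)      # [out_features, in_features]

C = torch._int_mm(A.reshape(-1, A.shape[-1]), B.t()).reshape(*A.shape[:-1], B.shape[0])
print(C.shape, C.dtype)  # torch.Size([4, 32, 48]) torch.int32
```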
if ipex_xpu:
@register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype)
@register_kernel("bitsandbytes::dequantize_blockwise", "xpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
) -> torch.Tensor:
shape = A.shape
out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device)
# void cdequantize_blockwise_fp32(
# float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
if dtype == torch.float16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
elif dtype == torch.bfloat16:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
elif dtype == torch.float32:
ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
return out.reshape(shape)
......@@ -286,11 +286,26 @@ def get_native_library() -> BNBNativeLibrary:
return BNBNativeLibrary(dll)
try:
# to support Intel CPU/GPU (XPU) backend
import intel_extension_for_pytorch as ipex
ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
except BaseException:
ipex_cpu = None
ipex_xpu = None
try:
lib = get_native_library()
except Exception as e:
error_msg = str(e)
logger.error(f"bitsandbytes library load error: {error_msg}\n", exc_info=True)
if not (ipex_cpu or ipex_xpu):
logger.error(
f"bitsandbytes library load error: {error_msg}\n If you are using Intel CPU/XPU, please install intel_extension_for_pytorch to enable required ops",
exc_info=True,
)
# create a mock with error messaging as fallback
lib = ErrorHandlerMockBNBNativeLibrary(error_msg)
......@@ -13,9 +13,9 @@ import torch
from torch import Tensor
from typing_extensions import deprecated
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict
from .cextension import lib
from .cextension import ipex_cpu, ipex_xpu, lib
name2qmap = {}
......@@ -1122,6 +1122,16 @@ def dequantize_4bit(
if absmax.dtype != torch.float32:
absmax = absmax.float()
# The IPEX weight layout is different, so it needs extra processing.
if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
return torch.ops.bitsandbytes.dequantize_nf4_ipex(
A,
absmax,
quant_state.blocksize,
quant_state.shape,
quant_state.dtype,
)
if out is not None:
torch.ops.bitsandbytes.dequantize_4bit.out(
A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
......@@ -1709,6 +1719,25 @@ def gemv_4bit(
if state.nested:
absmax = dequantize_blockwise(absmax, state.state2) + state.offset
if getattr(state, "ipex", False) and state.quant_type == "nf4":
# compute_dtype: 1 indicates fp16, 2 indicates bf16
compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
out = torch.ops.torch_ipex.woq_linear(
A,
B,
"nf4",
state.shape,
state.new_scales,
state.new_zeros,
None,
None,
state.blocksize,
compute_dtype,
1,
state.compensation,
)
return out
if out is not None:
torch.ops.bitsandbytes.gemv_4bit.out(
A,
......@@ -2507,3 +2536,49 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
return x.to(dtype)
else:
return None
def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
quant_state = linear.weight.quant_state
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()
quant_state.absmax = absmax
quant_state.nested = False
delattr(quant_state, "state2")
if x.device.type == "cpu" and ipex_cpu:
converted_weight = _reverse_4bit_compress_format(linear.weight.data)
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
"nf4",
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
None, # zero_points
None, # bias
None, # batch_size
quant_state.blocksize,
2,
)
elif x.device.type == "xpu" and ipex_xpu:
new_weight = _reverse_4bit_compress_format(linear.weight.data)
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
new_zeros = None
compensation = None
new_scales = list(new_scales)
if not linear.training and not x.requires_grad:
new_weight = new_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
else:
raise ValueError(
"Please check the device and the IPEX version. The device should be cpu or xpu, and the IPEX version should be >= 2.7."
)
linear.weight.data = new_weight.data
linear.weight.quant_state.ipex = True
linear.weight.quant_state.new_scales = new_scales
linear.weight.quant_state.new_zeros = new_zeros
linear.weight.quant_state.compensation = compensation
......@@ -11,11 +11,12 @@ from torch import Tensor, device, dtype, nn
import torch.nn.functional as F
import bitsandbytes as bnb
from bitsandbytes.functional import QuantState
from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import (
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer,
_reverse_4bit_compress_format,
)
T = TypeVar("T", bound="torch.nn.Module")
......@@ -444,6 +445,7 @@ class Linear4bit(nn.Linear):
self.compute_type_is_set = False
self.quant_state = None
self.quant_storage = quant_storage
self.ipex_linear_is_set = False
def set_compute_type(self, x):
if x.dtype in [torch.float32, torch.bfloat16]:
......@@ -470,13 +472,40 @@ class Linear4bit(nn.Linear):
save weight and bias,
then fill state_dict with components of quant_state
"""
if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False):
if self.weight.device.type == "cpu":
original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
self.weight, "nf4", self.weight.quant_state.shape, 2
)
self.weight.data = _reverse_4bit_compress_format(original_weight.data)
elif self.weight.device.type == "xpu":
self.weight.data = _reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
self.weight.quant_state.ipex = False
self.ipex_linear_is_set = False
super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
if getattr(self.weight, "quant_state", None) is not None:
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
def set_ipex_linear(self, x: torch.Tensor):
if (
not getattr(self.weight.quant_state, "ipex", False)
and self.weight.data.dtype == torch.uint8
and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
and self.weight.quant_state.quant_type == "nf4"
):
if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False):
_enable_ipex_fusion(self, x)
def forward(self, x: torch.Tensor):
# Check if ipex fusion can be used
if not self.ipex_linear_is_set and (ipex_cpu or ipex_xpu):
self.set_ipex_linear(x)
self.ipex_linear_is_set = True
fix_4bit_weight_quant_state_from_module(self)
# weights are cast automatically as Int8Params, but the bias has to be cast manually
......@@ -492,8 +521,10 @@ class Linear4bit(nn.Linear):
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
# IPEX CPU packs the weight into a 4D layout, so no transpose is needed there
weight = self.weight.t() if self.weight.dim() == 2 else self.weight
return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
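A usage sketch of the inference path through `Linear4bit` on CPU/XPU: with intel_extension_for_pytorch installed and the nf4/uint8/blocksize conditions above satisfied, `forward` may switch to the fused IPEX path, otherwise it falls back to `matmul_4bit`. The weight-wrapping pattern follows the test suite, but the exact quantize-on-move behavior can differ between bitsandbytes versions, so treat this as illustrative:

```python
# Illustrative: wrap a dense weight in Params4bit, move the layer to CPU/XPU, and
# run inference. Quantization is expected to happen on the move to `device`.
import torch
import bitsandbytes as bnb

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

dense = torch.nn.Linear(256, 128, bias=True, dtype=torch.bfloat16)
layer = bnb.nn.Linear4bit(256, 128, bias=True, quant_type="nf4",
                          compute_dtype=torch.bfloat16)
layer.weight = bnb.nn.Params4bit(data=dense.weight.data, requires_grad=False,
                                 quant_type="nf4")
layer.bias = dense.bias
layer = layer.to(device).eval()

x = torch.randn(4, 256, dtype=torch.bfloat16, device=device)
with torch.no_grad():
    y = layer(x)   # may take the fused IPEX path if IPEX is installed
print(y.shape, y.dtype)
```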
class LinearFP4(Linear4bit):
......@@ -644,8 +675,11 @@ class Int8Params(torch.nn.Parameter):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
if device.type != "cpu" or self.data.dtype != torch.int8:
return self._quantize(device)
else:
elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu"):
self.CB = self.data
new_param = Int8Params(
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
......
......@@ -38,6 +38,14 @@ def outlier_hook(module, input):
hook.remove()
# Convert between the standard 4-bit packing format and the IPEX packing format (swap the nibbles in each byte)
def _reverse_4bit_compress_format(weight: torch.Tensor):
out_1 = (weight & 0xF0) >> 4
out_2 = (weight & 0xF) << 4
out = out_1 | out_2
return out
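A quick illustration of the nibble swap performed by `_reverse_4bit_compress_format`: the high and low 4-bit values in each packed byte trade places, and applying the swap twice is the identity.

```python
# Swapping nibbles: 0x1F (high=1, low=F) becomes 0xF1; the swap is its own inverse.
import torch

def _reverse_4bit_compress_format(weight: torch.Tensor) -> torch.Tensor:
    return ((weight & 0xF0) >> 4) | ((weight & 0x0F) << 4)

packed = torch.tensor([0x1F, 0xA0, 0x3C], dtype=torch.uint8)
swapped = _reverse_4bit_compress_format(packed)
print([hex(v) for v in swapped.tolist()])                 # ['0xf1', '0xa', '0xc3']
assert torch.equal(_reverse_4bit_compress_format(swapped), packed)
```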
class OutlierTracer:
_instance = None
......
......@@ -238,15 +238,24 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise
#### Intel CPU + XPU
If you are using an Intel CPU on Linux or an Intel XPU on Linux/Windows, please follow the [instructions](https://pytorch-extension.intel.com/) or the commands below to install intel_extension_for_pytorch for better performance.

CPU: `pip install intel_extension_for_pytorch`
XPU: `pip install intel_extension_for_pytorch --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/`

Install bitsandbytes:
CPU: build the CPU C++ ops from source
```
git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/
cmake -DCOMPUTE_BACKEND=cpu -S .
make
pip install .
```
XPU:
```
pip install git+https://github.com/bitsandbytes-foundation/bitsandbytes.git
```
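After installing, a quick sanity check that IPEX is visible to PyTorch (assumes a reasonably recent PyTorch; `torch.xpu` may be absent on CPU-only builds):

```python
# Verify that intel_extension_for_pytorch is importable and whether an XPU device
# is visible to PyTorch.
import torch
import intel_extension_for_pytorch as ipex

print("ipex version:", ipex.__version__)
print("xpu available:", hasattr(torch, "xpu") and torch.xpu.is_available())
```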
</hfoption>
<hfoption id="Ascend NPU">
......
......@@ -180,9 +180,6 @@ def test_matmul_4bit(
compress_statistics,
quant_type,
):
if device == "cpu" and quant_type == "fp4":
pytest.xfail("Only nf4 is supported on CPU")
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
if has_bias == False:
......
......@@ -103,10 +103,9 @@ class Test8BitBlockwiseQuantizeFunctional:
if nested:
pytest.skip("Not a typical use case.")
if blocksize != 256:
pytest.skip("Only blocksize 256 is the typical one supported on CPU.")
pytest.skip("Only blocksize 256 is used in CPU/XPU")
if dtype != torch.float32:
pytest.xfail(f"CPU implementation currently only supports float32, got {dtype}")
pytest.skip("Only float32 is used in CPU/XPU")
diffs = []
reldiffs = []
......@@ -138,10 +137,11 @@ class Test8BitBlockwiseQuantizeFunctional:
abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs) / len(reldiffs)
if signed:
assert abserr < 0.0035
threshold_abserr = 0.0036 if device in ("cpu", "xpu") else 0.0035
assert abserr < threshold_abserr
assert relerr < 0.015
else:
assert abserr < 0.00175
assert abserr < (0.00175 if device in ("cpu", "xpu") else 0.0023)
assert relerr < 0.012
assert A2.dtype == dtype
......@@ -172,8 +172,8 @@ class Test8BitBlockwiseQuantizeFunctional:
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"])
def test_few_bit_quant(self, device, bits, method):
if device == "cpu" and bits != 8:
pytest.skip("CPU implementation only supports 8 bits")
if device in ("cpu", "xpu") and bits != 8:
pytest.skip("CPU/XPU implementation only supports 8 bits")
abserrs = []
relerrs = []
......@@ -1080,9 +1080,6 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
if device == "cpu" and quant_type != "nf4":
pytest.xfail("fp4 quantization is not supported on CPU")
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
......@@ -1115,9 +1112,6 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
def test_4bit_compressed_stats(self, device, quant_type, blocksize):
if device == "cpu" and quant_type != "nf4":
pytest.xfail("fp4 quantization is not supported on CPU")
errs1 = []
errs2 = []
for i in range(10):
......@@ -1190,12 +1184,6 @@ class TestQuantize4BitFunctional:
)
@pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
if device == "cpu":
if storage_type != "nf4":
pytest.xfail("fp4 quantization is not supported on CPU")
if quant_storage != torch.uint8:
pytest.xfail("Only uint8 storage is supported on CPU")
errs1 = []
errs2 = []
errs3 = []
......@@ -1342,13 +1330,6 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
if device == "cpu":
if storage_type != "nf4":
pytest.xfail("fp4 quantization is not supported on CPU")
if dtype == torch.bfloat16 and torch.__version__ < (2, 3):
pytest.xfail("eye doe not support bfloat16 on CPU in torch < 2.3")
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
dims = get_test_dims(0, 8192, n=dims)
......
......@@ -32,12 +32,6 @@ storage = {
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward):
if device == "cpu":
if quant_type == "fp4":
pytest.xfail("FP4 is not supported for CPU")
if quant_storage != "uint8":
pytest.xfail("Only uint8 storage is supported for CPU")
original_dtype = torch.float16
compute_dtype = None
layer_shape = (300, 400)
......@@ -194,13 +188,7 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
if device == "cpu":
if compress_statistics:
pytest.skip("Currently segfaults on CPU")
if quant_type == "fp4":
pytest.xfail("FP4 not supported on CPU")
tensor = torch.linspace(1, blocksize, blocksize)
tensor = torch.randn(300, 400)
param = bnb.nn.Params4bit(
data=tensor,
quant_type=quant_type,
......@@ -219,13 +207,7 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
if device == "cpu":
if compress_statistics:
pytest.skip("Currently segfaults on CPU")
if quant_type == "fp4":
pytest.xfail("FP4 not supported on CPU")
tensor = torch.linspace(1, blocksize, blocksize)
tensor = torch.randn(300, 400)
param = bnb.nn.Params4bit(
data=tensor,
quant_type=quant_type,
......@@ -251,13 +233,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
if device == "cpu":
if compress_statistics:
pytest.skip("Currently segfaults on CPU")
if quant_type == "fp4":
pytest.xfail("FP4 not supported on CPU")
original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32)
original_tensor = torch.randn(300, 400)
original_param = bnb.nn.Params4bit(
data=original_tensor,
quant_type=quant_type,
......
......@@ -391,12 +391,6 @@ def test_fp8linear():
ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
if device == "cpu":
if embedding_class is bnb.nn.EmbeddingFP4:
pytest.xfail("FP4 is not supported for CPU")
if quant_storage is not None and quant_storage != torch.uint8:
pytest.xfail("CPU only supports uint8 storage for 4bit")
num_embeddings = 128
src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
......@@ -442,12 +436,6 @@ def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim,
ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
if device == "cpu":
if embedding_class is bnb.nn.EmbeddingFP4:
pytest.xfail("FP4 is not supported for CPU")
if quant_storage is not None and quant_storage != torch.uint8:
pytest.xfail("CPU only supports uint8 storage for 4bit")
is_8bit = embedding_class is bnb.nn.Embedding8bit
num_embeddings = 128
......@@ -482,9 +470,6 @@ def test_embedding_error(device, embedding_class, input_shape, embedding_dim, qu
@pytest.mark.parametrize("device", get_available_devices())
def test_4bit_linear_warnings(device):
if device == "cpu":
pytest.xfail("gemv_4bit op is not yet implemented on CPU")
dim1 = 64
with pytest.warns(UserWarning, match=r"inference or training"):
......
......@@ -143,6 +143,10 @@ class TestInt8BlockwiseQuantOps:
assert out.dtype == dtype
assert out.device == A.device
# TODO: Enable this opcheck
if device == "xpu":
pytest.skip("The XPU implementation calls a torch op from inside another torch op, which fails the opcheck")
opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
......@@ -153,15 +157,9 @@ class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu" and quant_type != "nf4":
pytest.xfail("CPU implementation is only available for nf4")
if storage_dtype != torch.uint8:
pytest.xfail("Known issue with storage_dtype != uint8")
A = torch.randn(1024, 1024, dtype=dtype, device=device)
out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype)
out, absmax = torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, storage_dtype)
assert out.device == A.device
assert out.dtype == storage_dtype
......@@ -169,6 +167,10 @@ class Test4bitBlockwiseQuantOps:
assert absmax.device == A.device
assert absmax.dtype == torch.float32
# TODO: Enable this opcheck
if device in ("cpu", "xpu") and storage_dtype == torch.bfloat16:
pytest.skip("bf16 storage_dtype fails the torch opcheck on CPU/XPU")
opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
@pytest.mark.parametrize("device", get_available_devices())
......@@ -177,13 +179,6 @@ class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu":
if quant_type != "nf4":
pytest.xfail("CPU implementation is only available for nf4")
if storage_dtype != torch.uint8:
pytest.xfail("CPU implementation only supports uint8 storage")
shape = (128, 128)
n = prod(shape)
......@@ -215,9 +210,6 @@ class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "cpu":
pytest.xfail("CPU implementation is not available")
out_features = 1024
in_features = 256
......